op_function_common.cc 33.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/pybind/op_function_common.h"

17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
32
#include "paddle/fluid/operators/ops_extra_info.h"
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
#include "paddle/fluid/pybind/imperative.h"

namespace py = pybind11;
namespace paddle {
namespace pybind {

class OpAttrTypeMap {
 public:
  static OpAttrTypeMap& Instance() {
    static OpAttrTypeMap g_op_attr_type_map;
    return g_op_attr_type_map;
  }

  std::unordered_map<
      std::string,
      std::unordered_map<std::string, paddle::framework::proto::AttrType>>&
  Map() {
    return ops_attrtype_map_;
  }

 private:
  OpAttrTypeMap() = default;
  std::unordered_map<
      std::string,
      std::unordered_map<std::string, paddle::framework::proto::AttrType>>
      ops_attrtype_map_;
};

// Python type objects defined in other translation units. In this file they
// are only used as the class argument of PyObject_IsInstance() checks.
extern PyTypeObject* g_varbase_pytype;
extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_blockdesc_pytype;
64
// Python type object defined in another translation unit; used below as the
// class argument of PyObject_IsInstance() in the numeric argument checks.
extern PyTypeObject* p_tensor_type;
65 66 67 68 69 70

// Reports whether the object that `obj` points at is a Python bool.
bool PyObject_CheckBool(PyObject** obj) {
  PyObject* candidate = *obj;
  return PyBool_Check(candidate);
}

// Decides whether `*obj` can be consumed as an integer argument.
//
// Returns true without touching `*obj` when it is a Python int (bool is
// excluded) or an instance of g_vartype_pytype, g_varbase_pytype or
// p_tensor_type.
//
// Otherwise, if the type name contains "numpy", the object is converted with
// PyNumber_Long(). On success `*obj` is REPLACED by the conversion result,
// which is a new reference: a caller that wants to avoid leaking it must
// compare `*obj` with the pointer it passed in and Py_DECREF the replacement.
// On failure the Python error raised by the conversion is cleared, so a
// `false` return never leaves a pending Python exception behind.
bool PyObject_CheckLongOrToLong(PyObject** obj) {
  if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
      PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) ||  // NOLINT
      PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) ||  // NOLINT
      PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) {     // NOLINT
    return true;
  }

  if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name)  // NOLINT
          .find("numpy") != std::string::npos) {
    auto to = PyNumber_Long(*obj);
    if (to) {
      *obj = to;
      return true;
    }
    // The conversion failed and set a Python error; callers report their own
    // InvalidArgument error, so do not leave the stale one pending.
    PyErr_Clear();
  }

  return false;
}

// Decides whether `*obj` can be consumed as a floating-point argument.
//
// Returns true without touching `*obj` when it is a Python float or int, or
// an instance of g_varbase_pytype or p_tensor_type (users sometimes pass an
// int where the attribute is a float).
//
// Otherwise, if the type name contains "numpy", the object is converted with
// PyNumber_Float(). On success `*obj` is REPLACED by the conversion result,
// which is a new reference: a caller that wants to avoid leaking it must
// compare `*obj` with the pointer it passed in and Py_DECREF the replacement.
// On failure the Python error raised by the conversion is cleared, so a
// `false` return never leaves a pending Python exception behind.
bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
  // sometimes users provide PyLong or numpy.int64 but attr is float
  if (PyFloat_Check(*obj) || PyLong_Check(*obj) ||
      PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) ||  // NOLINT
      PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) {     // NOLINT
    return true;
  }
  if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name)  // NOLINT
          .find("numpy") != std::string::npos) {
    auto to = PyNumber_Float(*obj);
    if (to) {
      *obj = to;
      return true;
    }
    // The conversion failed and set a Python error; callers report their own
    // InvalidArgument error, so do not leave the stale one pending.
    PyErr_Clear();
  }
  return false;
}

// Reports whether `obj` is a Python str (unicode) object.
bool PyObject_CheckString(PyObject* obj) {
  return PyUnicode_Check(obj);
}

108 109
bool CastPyArg2Boolean(PyObject* obj,
                       const std::string& op_type,
110
                       ssize_t arg_pos) {
111
  if (obj == Py_None) {
112 113
    return false;  // To be compatible with QA integration testing. Some
                   // test case pass in None.
114
  } else if (obj == Py_True) {
115
    return true;
116
  } else if (obj == Py_False) {
117
    return false;
118 119 120 121
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "bool, but got %s",
122 123
        op_type,
        arg_pos + 1,
124 125
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
126 127 128 129 130 131

  return false;
}

void CastPyArg2AttrBoolean(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
132 133
                           const std::string& key,
                           const std::string& op_type,
134 135 136 137 138 139 140 141 142 143 144
                           ssize_t arg_pos) {
  attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos);
}

// Converts a Python argument to a C++ int.
//
// Acceptance is decided by PyObject_CheckLongOrToLong; the value is read with
// PyLong_AsLong and narrowed to int. Objects that are not accepted raise
// InvalidArgument naming the op, the position and the actual type.
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
  PyObject* const original = obj;
  if (!PyObject_CheckLongOrToLong(&obj)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "int, but got %s",
        op_type,
        arg_pos + 1,
        Py_TYPE(obj)->tp_name));
  }

  const int result = static_cast<int>(PyLong_AsLong(obj));
  // The check may have swapped `obj` for a freshly created Python int (numpy
  // scalars); that temporary is a new reference owned by this function.
  if (obj != original) {
    Py_DECREF(obj);
  }
  return result;
}

void CastPyArg2AttrInt(PyObject* obj,
                       paddle::framework::AttributeMap& attrs,  // NOLINT
155 156
                       const std::string& key,
                       const std::string& op_type,
157
                       ssize_t arg_pos) {
158 159 160
  attrs[key] = CastPyArg2Int(obj, op_type, arg_pos);
}

161 162
int64_t CastPyArg2Long(PyObject* obj,
                       const std::string& op_type,
163
                       ssize_t arg_pos) {
164
  if (PyObject_CheckLongOrToLong(&obj)) {
165
    return (int64_t)PyLong_AsLongLong(obj);  // NOLINT
166 167 168
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
169
        "long, but got %s",
170 171
        op_type,
        arg_pos + 1,
172 173
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
174 175

  return 0;
176 177 178 179
}

void CastPyArg2AttrLong(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
180 181
                        const std::string& key,
                        const std::string& op_type,
182
                        ssize_t arg_pos) {
183 184 185
  attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
}

186 187
float CastPyArg2Float(PyObject* obj,
                      const std::string& op_type,
188
                      ssize_t arg_pos) {
189 190 191
  return static_cast<float>(CastPyArg2Double(obj, op_type, arg_pos));
}

192 193 194 195 196 197 198 199
// Converts `obj` with CastPyArg2Float and stores the result in attrs[key].
void CastPyArg2AttrFloat(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
                         const std::string& key,
                         const std::string& op_type,
                         ssize_t arg_pos) {
  const float converted = CastPyArg2Float(obj, op_type, arg_pos);
  attrs[key] = converted;
}

200 201
double CastPyArg2Double(PyObject* obj,
                        const std::string& op_type,
202
                        ssize_t arg_pos) {
203
  if (PyObject_CheckFloatOrToFloat(&obj)) {
204
    return PyFloat_AsDouble(obj);  // NOLINT
205 206 207
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
208
        "double, but got %s",
209 210
        op_type,
        arg_pos + 1,
211 212
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
213 214

  return 0.0;
215 216
}

217 218 219 220 221 222
// Converts `obj` with CastPyArg2Double and stores the result in attrs[key].
void CastPyArg2AttrDouble(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos) {
  const double converted = CastPyArg2Double(obj, op_type, arg_pos);
  attrs[key] = converted;
}

225 226
std::string CastPyArg2String(PyObject* obj,
                             const std::string& op_type,
227
                             ssize_t arg_pos) {
228 229 230 231
  if (PyObject_CheckString(obj)) {
    Py_ssize_t size;
    const char* data;
    data = PyUnicode_AsUTF8AndSize(obj, &size);
232
    return std::string(data, (size_t)size);  // NOLINT
233 234 235 236
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "str, but got %s",
237 238
        op_type,
        arg_pos + 1,
239 240
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
241 242

  return "";
243 244
}

245 246
void CastPyArg2AttrString(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
247 248
                          const std::string& key,
                          const std::string& op_type,
249 250 251 252
                          ssize_t arg_pos) {
  attrs[key] = CastPyArg2String(obj, op_type, arg_pos);
}

253 254
std::vector<bool> CastPyArg2Booleans(PyObject* obj,
                                     const std::string& op_type,
255 256
                                     ssize_t arg_pos) {
  std::vector<bool> value;
257 258 259 260 261 262 263 264 265 266 267
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckBool(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of bool, but got %s at pos %d",
268 269
            op_type,
            arg_pos + 1,
270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckBool(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of bool, but got %s at pos %d",
285 286
            op_type,
            arg_pos + 1,
287 288 289 290 291 292 293 294
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
295 296
        op_type,
        arg_pos + 1,
297 298
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
299 300

  return value;
301 302
}

303 304
void CastPyArg2AttrBooleans(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
305 306
                            const std::string& key,
                            const std::string& op_type,
307 308 309 310
                            ssize_t arg_pos) {
  attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos);
}

311 312
std::vector<int> CastPyArg2Ints(PyObject* obj,
                                const std::string& op_type,
313 314
                                ssize_t arg_pos) {
  std::vector<int> value;
315 316
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
Z
zyfncg 已提交
317
    value.reserve(len);
318 319 320 321 322 323 324 325 326
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
327 328
            op_type,
            arg_pos + 1,
329 330 331 332 333 334
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
Z
zyfncg 已提交
335
    value.reserve(len);
336 337 338 339 340 341 342 343 344
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
345 346
            op_type,
            arg_pos + 1,
347 348 349 350 351 352
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PySequence_Check(obj)) {
    Py_ssize_t len = PySequence_Size(obj);
Z
zyfncg 已提交
353
    value.reserve(len);
354 355 356 357 358 359 360 361 362
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PySequence_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
363 364
            op_type,
            arg_pos + 1,
365 366 367 368 369 370 371 372
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
373 374
        op_type,
        arg_pos + 1,
375 376
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
377 378

  return value;
379 380
}

381 382
void CastPyArg2AttrInts(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
383 384
                        const std::string& key,
                        const std::string& op_type,
385 386 387 388
                        ssize_t arg_pos) {
  attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos);
}

389 390
std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
                                     const std::string& op_type,
391 392
                                     ssize_t arg_pos) {
  std::vector<int64_t> value;
393 394 395 396 397 398 399 400 401 402 403
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
404 405
            op_type,
            arg_pos + 1,
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
421 422
            op_type,
            arg_pos + 1,
423 424 425 426 427 428 429 430 431 432 433 434 435 436 437
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PySequence_Check(obj)) {
    Py_ssize_t len = PySequence_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PySequence_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
438 439
            op_type,
            arg_pos + 1,
440 441 442 443
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
444
  } else if ((PyObject*)obj != Py_None) {  // NOLINT
445 446 447
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
448 449
        op_type,
        arg_pos + 1,
450 451
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
452 453

  return value;
454 455
}

456 457
void CastPyArg2AttrLongs(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
458 459
                         const std::string& key,
                         const std::string& op_type,
460 461 462 463
                         ssize_t arg_pos) {
  attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos);
}

464 465
std::vector<float> CastPyArg2Floats(PyObject* obj,
                                    const std::string& op_type,
466 467
                                    ssize_t arg_pos) {
  std::vector<float> value;
468 469 470 471 472 473 474 475 476 477 478
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckFloatOrToFloat(&item)) {
        value.emplace_back(PyFloat_AsDouble(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but got %s at pos %d",
479 480
            op_type,
            arg_pos + 1,
481 482 483 484 485 486 487 488 489 490 491 492 493 494 495
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckFloatOrToFloat(&item)) {
        value.emplace_back(PyFloat_AsDouble(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but got %s at pos %d",
496 497
            op_type,
            arg_pos + 1,
498 499 500 501 502 503 504 505 506 507 508 509 510 511 512
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PySequence_Check(obj)) {
    Py_ssize_t len = PySequence_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PySequence_GetItem(obj, i);
      if (PyObject_CheckFloatOrToFloat(&item)) {
        value.emplace_back(PyFloat_AsDouble(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but got %s at pos %d",
513 514
            op_type,
            arg_pos + 1,
515 516 517 518 519 520 521 522
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
523 524
        op_type,
        arg_pos + 1,
525 526
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
527 528

  return value;
529 530
}

531 532
void CastPyArg2AttrFloats(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
533 534
                          const std::string& key,
                          const std::string& op_type,
535 536 537 538 539 540 541 542
                          ssize_t arg_pos) {
  attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos);
}

// Converts a Python list, tuple or generic sequence of numbers into
// std::vector<double>.
//
// Each element must be accepted by PyObject_CheckFloatOrToFloat and is read
// with PyFloat_AsDouble. Raises InvalidArgument when `obj` is not a list,
// tuple or sequence, when its length or an element cannot be obtained, or
// when an element is not accepted.
std::vector<double> CastPyArg2Float64s(PyObject* obj,
                                       const std::string& op_type,
                                       ssize_t arg_pos) {
  std::vector<double> value;

  // Validates one element and appends it. `borrowed` stays owned by the
  // caller; a temporary created by the numpy conversion is released here.
  auto append_item = [&](PyObject* borrowed, Py_ssize_t i) {
    PyObject* item = borrowed;
    if (!PyObject_CheckFloatOrToFloat(&item)) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "list of float, but got %s at pos %d",
          op_type,
          arg_pos + 1,
          Py_TYPE(item)->tp_name,
          i));
    }
    value.emplace_back(PyFloat_AsDouble(item));
    if (item != borrowed) {
      Py_DECREF(item);
    }
  };

  if (PyList_Check(obj)) {
    const Py_ssize_t len = PyList_Size(obj);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyList_GetItem(obj, i), i);  // borrowed reference
    }
  } else if (PyTuple_Check(obj)) {
    const Py_ssize_t len = PyTuple_Size(obj);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyTuple_GetItem(obj, i), i);  // borrowed reference
    }
  } else if (PySequence_Check(obj)) {
    const Py_ssize_t len = PySequence_Size(obj);
    if (len < 0) {
      // The object has no usable length; a Python error is set.
      PyErr_Clear();
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "list or tuple, but got %s",
          op_type,
          arg_pos + 1,
          Py_TYPE(obj)->tp_name));
    }
    for (Py_ssize_t i = 0; i < len; i++) {
      // PySequence_GetItem returns a NEW reference; the guard releases it on
      // both the normal and the throwing path.
      std::unique_ptr<PyObject, decltype(&Py_DecRef)> owned(
          PySequence_GetItem(obj, i), &Py_DecRef);
      if (owned == nullptr) {
        PyErr_Clear();
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but failed to read the item at pos %d",
            op_type,
            arg_pos + 1,
            i));
      }
      append_item(owned.get(), i);
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
        op_type,
        arg_pos + 1,
        Py_TYPE(obj)->tp_name));
  }

  return value;
}

606 607
void CastPyArg2AttrFloat64s(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
608 609
                            const std::string& key,
                            const std::string& op_type,
610 611 612 613 614 615 616 617
                            ssize_t arg_pos) {
  attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos);
}

// Converts a Python list or tuple of str into std::vector<std::string>, each
// entry holding the UTF-8 bytes of the corresponding element.
//
// Raises InvalidArgument when `obj` is neither a list nor a tuple, when an
// element is not a str, or when an element's UTF-8 representation cannot be
// produced (PyUnicode_AsUTF8AndSize returns NULL).
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
                                           const std::string& op_type,
                                           ssize_t arg_pos) {
  std::vector<std::string> value;

  // Validates one borrowed element and appends its UTF-8 contents.
  auto append_item = [&](PyObject* item, Py_ssize_t i) {
    if (!PyObject_CheckString(item)) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "list of str, but got %s at pos %d",
          op_type,
          arg_pos + 1,
          Py_TYPE(item)->tp_name,
          i));
    }
    Py_ssize_t size = 0;
    const char* data = PyUnicode_AsUTF8AndSize(item, &size);
    if (data == nullptr) {
      // Encoding failed and a Python error is set. Building a std::string
      // from a null pointer would be undefined behaviour.
      PyErr_Clear();
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "list of str, but the item at pos %d cannot be encoded as UTF-8",
          op_type,
          arg_pos + 1,
          i));
    }
    value.emplace_back(data, static_cast<size_t>(size));
  };

  if (PyList_Check(obj)) {
    const Py_ssize_t len = PyList_Size(obj);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyList_GetItem(obj, i), i);
    }
  } else if (PyTuple_Check(obj)) {
    const Py_ssize_t len = PyTuple_Size(obj);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyTuple_GetItem(obj, i), i);
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
        op_type,
        arg_pos + 1,
        Py_TYPE(obj)->tp_name));
  }

  return value;
}

void CastPyArg2AttrStrings(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
672 673
                           const std::string& key,
                           const std::string& op_type,
674 675
                           ssize_t arg_pos) {
  attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos);
676 677 678 679
}

void CastPyArg2AttrBlock(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
680 681
                         const std::string& key,
                         const std::string& op_type,
682 683 684 685 686 687 688 689 690
                         ssize_t arg_pos) {
  ::pybind11::detail::instance* inst =
      (::pybind11::detail::instance*)obj;  // NOLINT

  if (!PyObject_IsInstance((PyObject*)inst,                   // NOLINT
                           (PyObject*)g_blockdesc_pytype)) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "BlockDesc, but got %s",
691 692
        op_type,
        arg_pos + 1,
693 694 695 696 697 698 699 700
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
  void** vh = inst->simple_layout ? inst->simple_value_holder
                                  : &inst->nonsimple.values_and_holders[0];
  attrs[key] = reinterpret_cast<paddle::framework::BlockDesc*&>(vh[0]);
}

void ConstructAttrMapFromPyArgs(
701 702 703 704 705 706 707
    const std::string& op_type,
    PyObject* args,
    ssize_t attr_start,
    ssize_t attr_end,
    paddle::framework::AttributeMap& attrs) {  // NOLINT
  PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2,
                    0,
708 709 710
                    platform::errors::InvalidArgument(
                        "The number of arguments for attributes should be even "
                        "but attr_start = %d, attr_end = %d.",
711 712
                        attr_start,
                        attr_end));
713 714 715 716 717 718 719 720 721 722 723 724 725 726

  auto attr_type_map = &(OpAttrTypeMap::Instance().Map()[op_type]);

  PyObject* obj = nullptr;
  for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) {
    Py_ssize_t key_len;
    const char* key_ptr;
    obj = PyTuple_GET_ITEM(args, arg_pos);
    if (PyObject_CheckString(obj)) {
      key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be str, but got "
          "%s",
727 728 729
          op_type,
          arg_pos,
          ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746
    }

    std::string key(key_ptr, (size_t)key_len);  // NOLINT
    auto iter = attr_type_map->find(key);
    if (iter == attr_type_map->end()) {
      continue;
    }

    obj = PyTuple_GET_ITEM(args, arg_pos + 1);

    switch (iter->second) {
      case paddle::framework::proto::AttrType::INT:
        CastPyArg2AttrInt(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::FLOAT:
        CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos);
        break;
747 748 749
      case paddle::framework::proto::AttrType::FLOAT64:
        CastPyArg2AttrDouble(obj, attrs, key, op_type, arg_pos);
        break;
750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786
      case paddle::framework::proto::AttrType::STRING:
        CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::INTS:
        CastPyArg2AttrInts(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::FLOATS:
        CastPyArg2AttrFloats(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::STRINGS:
        CastPyArg2AttrStrings(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::BOOLEAN:
        CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::BOOLEANS:
        CastPyArg2AttrBooleans(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::LONG:
        CastPyArg2AttrLong(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::LONGS:
        CastPyArg2AttrLongs(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::FLOAT64S:
        CastPyArg2AttrFloat64s(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::BLOCK:
        CastPyArg2AttrBlock(obj, attrs, key, op_type, arg_pos);
        break;
      default:
        break;
    }
  }
}

std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
787 788 789 790 791
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable) {
792 793 794 795 796 797 798 799 800 801 802
  ::pybind11::detail::instance* inst =
      (::pybind11::detail::instance*)PyTuple_GET_ITEM(args, arg_idx);

  if (PyTuple_Check((PyObject*)inst)) {  // NOLINT
    inst = (::pybind11::detail::instance*)PyTuple_GET_ITEM(inst, 0);
  }

  if (inst == nullptr || (PyObject*)inst == Py_None) {  // NOLINT
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be Tensor, but got None",
803 804 805
          op_type,
          arg_name,
          arg_idx));
806 807 808 809 810 811 812 813 814
    }
    return nullptr;
  }

  if (!PyObject_IsInstance((PyObject*)inst,                 // NOLINT
                           (PyObject*)g_varbase_pytype)) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be Tensor, but got "
        "%s",
815 816 817
        op_type,
        arg_name,
        arg_idx,
818 819 820 821 822 823 824 825 826
        ((PyTypeObject*)((PyObject*)inst)->ob_type)->tp_name));  // NOLINT
  }

  void** vh = inst->simple_layout ? inst->simple_value_holder
                                  : &inst->nonsimple.values_and_holders[0];
  return reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(vh[1]);
}

// Extracts the VarBase objects held by the Python list or tuple at
// args[arg_idx].
//
// A missing or None argument yields an empty vector when `dispensable` is
// true and raises InvalidArgument otherwise. InvalidArgument is also raised
// when the argument is neither a list nor a tuple, when it is empty, or when
// an element is not an instance of g_varbase_pytype.
std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable) {
  PyObject* container = PyTuple_GET_ITEM(args, arg_idx);

  if (container == nullptr || container == Py_None) {
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be list of Tensor, but got "
          "None",
          op_type,
          arg_name,
          arg_idx));
    }
    return {};
  }

  const bool is_list = PyList_Check(container) != 0;
  if (!is_list && !PyTuple_Check(container)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be list of Tensors, but got "
        "%s",
        op_type,
        arg_name,
        arg_idx,
        Py_TYPE(container)->tp_name));
  }

  const Py_ssize_t count =
      is_list ? PyList_Size(container) : PyTuple_Size(container);
  if (count == 0) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be list of Tensors, but got "
        "empty list",
        op_type,
        arg_name,
        arg_idx));
  }

  std::vector<std::shared_ptr<imperative::VarBase>> tensors;
  for (Py_ssize_t idx = 0; idx < count; ++idx) {
    // Both getters hand back borrowed references.
    PyObject* element = is_list ? PyList_GetItem(container, idx)
                                : PyTuple_GetItem(container, idx);
    if (!PyObject_IsInstance(
            element, reinterpret_cast<PyObject*>(g_varbase_pytype))) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be list of Tensors, but "
          "got list of "
          "%s",
          op_type,
          arg_name,
          arg_idx,
          Py_TYPE(element)->tp_name));
    }

    // Read the holder (second slot) straight from the pybind11 instance
    // storage, which is either inline or separately allocated.
    auto* wrapper = reinterpret_cast<::pybind11::detail::instance*>(element);
    void** slots = nullptr;
    if (wrapper->simple_layout) {
      slots = wrapper->simple_value_holder;
    } else {
      slots = &wrapper->nonsimple.values_and_holders[0];
    }
    tensors.emplace_back(
        reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(
            slots[1]));
  }

  return tensors;
}

unsigned long GetUnsignedLongFromArgs(  // NOLINT
922 923 924 925 926
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable) {
927 928 929 930 931 932
  PyObject* item = PyTuple_GET_ITEM(args, arg_idx);

  if (item == nullptr) {
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be long, but got None",
933 934 935
          op_type,
          arg_name,
          arg_idx));
936 937 938 939 940 941 942 943 944 945
    }
    return 0;
  }

  if (PyObject_CheckLongOrToLong(&item)) {
    return PyLong_AsUnsignedLong(item);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be "
        "long, but got %s",
946 947 948
        op_type,
        arg_name,
        arg_idx,
949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964
        ((PyTypeObject*)item->ob_type)->tp_name));  // NOLINT
  }
}

void InitOpsAttrTypeMap() {
  auto op_info_map = paddle::framework::OpInfoMap::Instance().map();
  for (auto iter = op_info_map.begin(); iter != op_info_map.end(); ++iter) {
    auto op_proto = iter->second.proto_;
    if (op_proto == nullptr) {
      continue;
    }
    auto attrs_proto = op_proto->attrs();
    for (auto& attr : attrs_proto) {
      OpAttrTypeMap::Instance().Map()[iter->first][attr.name()] = attr.type();
    }
  }
965 966 967 968 969 970 971 972 973
  const auto& extra_attr_maps =
      operators::ExtraInfoUtils::Instance().GetAllExtraAttrsMap();
  for (const auto& extra_attrs : extra_attr_maps) {
    for (auto& attr : extra_attrs.second) {
      OpAttrTypeMap::Instance().Map()[extra_attrs.first][attr.first] =
          static_cast<paddle::framework::proto::AttrType>(attr.second.index() -
                                                          1);
    }
  }
974 975
}

976 977 978
ssize_t GetIdxFromCoreOpsInfoMap(
    const std::unordered_map<std::string, std::vector<std::string>>&
        core_ops_info_map,
979 980
    const std::string& op_type,
    const std::string& name) {
981 982 983 984 985 986 987 988 989 990 991 992
  // `core_ops_info_map` can be `core_ops_args_info` or `core_ops_returns_info`.
  // `core_ops_args_info`: get index from core_ops_args_info[op_type] according
  // to input name.
  // `core_ops_returns_info`: get index from core_ops_returns_info[op_type]
  // according to return name.
  if (!core_ops_info_map.count(op_type)) {
    PADDLE_THROW(platform::errors::Fatal(
        "Op %s is not found in core_ops_*_info map.", op_type));
  } else {
    auto args_list = core_ops_info_map.at(op_type);
    auto it = std::find(args_list.begin(), args_list.end(), name);
    if (it == args_list.end()) {
993 994
      PADDLE_THROW(platform::errors::Fatal(
          "%s is not found in op %s's args.", name, op_type));
995 996 997 998 999 1000 1001
    } else {
      return std::distance(args_list.begin(), it);
    }
  }
  return -1;
}

1002 1003
}  // namespace pybind
}  // namespace paddle