op_function_common.cc 33.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/pybind/op_function_common.h"

17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
32
#include "paddle/fluid/operators/ops_extra_info.h"
33
#include "paddle/fluid/pybind/eager.h"
34
#include "paddle/fluid/pybind/eager_utils.h"
35
#include "paddle/fluid/pybind/imperative.h"
36
#include "paddle/phi/common/complex.h"
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64

namespace paddle {
namespace pybind {

class OpAttrTypeMap {
 public:
  static OpAttrTypeMap& Instance() {
    static OpAttrTypeMap g_op_attr_type_map;
    return g_op_attr_type_map;
  }

  std::unordered_map<
      std::string,
      std::unordered_map<std::string, paddle::framework::proto::AttrType>>&
  Map() {
    return ops_attrtype_map_;
  }

 private:
  OpAttrTypeMap() = default;
  std::unordered_map<
      std::string,
      std::unordered_map<std::string, paddle::framework::proto::AttrType>>
      ops_attrtype_map_;
};

// Python type objects declared extern here and defined elsewhere; they are
// used below with PyObject_TypeCheck to recognise Paddle objects passed in
// from Python (Tensor, VarType, BlockDesc).
extern PyTypeObject* p_tensor_type;
extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_blockdesc_pytype;

// True only for Python bool objects (PyBool_Check). Takes PyObject** to match
// the signature of the *OrTo* checkers, but never modifies *obj.
bool PyObject_CheckBool(PyObject** obj) {
  PyObject* const candidate = *obj;
  return PyBool_Check(candidate);
}

// Returns true when *obj can subsequently be read with PyLong_AsLong /
// PyLong_AsLongLong. Accepted inputs:
//   * a Python int (bool is explicitly excluded),
//   * an instance of g_vartype_pytype,
//   * a Tensor (p_tensor_type) holding exactly one element,
//   * any object whose type name contains "numpy" and that PyNumber_Long
//     can convert.
//
// Ownership: only in the numpy case is *obj replaced, and then with the
// result of PyNumber_Long, which is a NEW reference owned by the caller.
// Callers detect this by comparing the pointer before and after the call and
// must Py_DECREF the replacement. On failure *obj is left untouched.
bool PyObject_CheckLongOrToLong(PyObject** obj) {
  if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
      PyObject_TypeCheck(*obj, g_vartype_pytype) ||        // NOLINT
      (PyObject_TypeCheck(*obj, p_tensor_type) &&          // NOLINT
       (((TensorObject*)(*obj))->tensor.numel() == 1))) {  // NOLINT
    return true;
  }

  if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name)  // NOLINT
          .find("numpy") != std::string::npos) {
    PyObject* to = PyNumber_Long(*obj);
    if (to != nullptr) {
      *obj = to;
      return true;
    }
    // PyNumber_Long failed and set a Python exception. Callers in this file
    // either try another check or raise their own InvalidArgument, so drop
    // it instead of leaving a stale exception pending for the interpreter.
    PyErr_Clear();
  }

  return false;
}

// Returns true when *obj can subsequently be read with PyFloat_AsDouble.
// Accepted inputs:
//   * a Python float or Python int (PyLong_Check, so bool is accepted too;
//     users sometimes pass an int where the attribute is a float),
//   * a Tensor (p_tensor_type) holding exactly one element,
//   * any object whose type name contains "numpy" and that PyNumber_Float
//     can convert.
//
// Ownership: only in the numpy case is *obj replaced, and then with the
// result of PyNumber_Float, a NEW reference owned by the caller. Callers
// detect this by comparing the pointer before and after the call and must
// Py_DECREF the replacement. On failure *obj is left untouched.
bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
  if (PyFloat_Check(*obj) || PyLong_Check(*obj) ||
      (PyObject_TypeCheck(*obj, p_tensor_type) &&          // NOLINT
       (((TensorObject*)(*obj))->tensor.numel() == 1))) {  // NOLINT
    return true;
  }

  if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name)  // NOLINT
          .find("numpy") != std::string::npos) {
    PyObject* to = PyNumber_Float(*obj);
    if (to != nullptr) {
      *obj = to;
      return true;
    }
    // PyNumber_Float failed and set a Python exception. Callers in this file
    // either try another check or raise their own InvalidArgument, so drop
    // it instead of leaving a stale exception pending for the interpreter.
    PyErr_Clear();
  }

  return false;
}

107 108
// True when *obj is a Python complex, int or float, an instance of
// g_vartype_pytype, or any Tensor (no element-count check here). Never
// replaces *obj. numpy complex scalars are not recognised yet.
bool PyObject_CheckComplexOrToComplex(PyObject** obj) {
  PyObject* const candidate = *obj;
  if (PyComplex_Check(candidate)) return true;
  if (PyLong_Check(candidate)) return true;
  if (PyFloat_Check(candidate)) return true;
  if (PyObject_TypeCheck(candidate, g_vartype_pytype)) return true;  // NOLINT
  if (PyObject_TypeCheck(candidate, p_tensor_type)) return true;     // NOLINT
  return false;
}

117 118
// True when `obj` is a Python str (or a str subclass), per PyUnicode_Check.
bool PyObject_CheckString(PyObject* obj) {
  const bool is_unicode = PyUnicode_Check(obj);
  return is_unicode;
}

119 120
bool CastPyArg2Boolean(PyObject* obj,
                       const std::string& op_type,
121
                       ssize_t arg_pos) {
122
  if (obj == Py_None) {
123 124
    return false;  // To be compatible with QA integration testing. Some
                   // test case pass in None.
125
  } else if (obj == Py_True) {
126
    return true;
127
  } else if (obj == Py_False) {
128
    return false;
129 130 131 132
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "bool, but got %s",
133 134
        op_type,
        arg_pos + 1,
135 136
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
137 138 139 140 141 142

  return false;
}

void CastPyArg2AttrBoolean(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
143 144
                           const std::string& key,
                           const std::string& op_type,
145 146 147 148 149 150 151 152 153 154 155
                           ssize_t arg_pos) {
  attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos);
}

// Reads an integer-like Python object (see PyObject_CheckLongOrToLong) with
// PyLong_AsLong and narrows it to int. Raises InvalidArgument naming
// `op_type` and the 1-based argument position for any other object.
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
  PyObject* converted = obj;
  if (PyObject_CheckLongOrToLong(&converted)) {
    const long value = PyLong_AsLong(converted);
    // For numpy scalars the checker swapped in a NEW reference to a Python
    // int; it is owned here and must be released, otherwise every call with
    // a numpy value leaks one object.
    if (converted != obj) {
      Py_DECREF(converted);
    }
    return static_cast<int>(value);
  }

  PADDLE_THROW(platform::errors::InvalidArgument(
      "%s(): argument (position %d) must be "
      "int, but got %s",
      op_type,
      arg_pos + 1,
      ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  return 0;
}

void CastPyArg2AttrInt(PyObject* obj,
                       paddle::framework::AttributeMap& attrs,  // NOLINT
166 167
                       const std::string& key,
                       const std::string& op_type,
168
                       ssize_t arg_pos) {
169 170 171
  attrs[key] = CastPyArg2Int(obj, op_type, arg_pos);
}

172 173
int64_t CastPyArg2Long(PyObject* obj,
                       const std::string& op_type,
174
                       ssize_t arg_pos) {
175
  if (PyObject_CheckLongOrToLong(&obj)) {
176
    return (int64_t)PyLong_AsLongLong(obj);  // NOLINT
177 178 179
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
180
        "long, but got %s",
181 182
        op_type,
        arg_pos + 1,
183 184
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
185 186

  return 0;
187 188 189 190
}

void CastPyArg2AttrLong(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
191 192
                        const std::string& key,
                        const std::string& op_type,
193
                        ssize_t arg_pos) {
194 195 196
  attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
}

197 198 199 200 201 202 203 204
// Parses `obj` with CastPyArg2Scalar (declared elsewhere) and stores the
// result in attrs[key], replacing any previous value. Parsing happens first,
// so a parse error leaves `attrs` unchanged.
void CastPyArg2AttrScalar(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos) {
  auto parsed = CastPyArg2Scalar(obj, op_type, arg_pos);
  attrs[key] = std::move(parsed);
}

205 206 207 208 209 210
// Parses `obj` as a double via CastPyArg2Double (same accepted inputs and
// errors), then converts that double to float16.
float16 CastPyArg2Float16(PyObject* obj,
                          const std::string& op_type,
                          ssize_t arg_pos) {
  const double parsed = CastPyArg2Double(obj, op_type, arg_pos);
  return static_cast<float16>(parsed);
}

211 212
float CastPyArg2Float(PyObject* obj,
                      const std::string& op_type,
213
                      ssize_t arg_pos) {
214 215 216
  return static_cast<float>(CastPyArg2Double(obj, op_type, arg_pos));
}

217 218 219 220 221 222 223 224
// Parses `obj` with CastPyArg2Float and stores the float in attrs[key],
// replacing any previous value. A parse error is raised before `attrs` is
// modified.
void CastPyArg2AttrFloat(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
                         const std::string& key,
                         const std::string& op_type,
                         ssize_t arg_pos) {
  const float parsed = CastPyArg2Float(obj, op_type, arg_pos);
  attrs[key] = parsed;
}

225 226
double CastPyArg2Double(PyObject* obj,
                        const std::string& op_type,
227
                        ssize_t arg_pos) {
228
  if (PyObject_CheckFloatOrToFloat(&obj)) {
229
    return PyFloat_AsDouble(obj);  // NOLINT
230 231 232
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
233
        "double, but got %s",
234 235
        op_type,
        arg_pos + 1,
236 237
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
238 239

  return 0.0;
240 241
}

242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
// Converts a Python complex object to phi::dtype::complex<float>. Only
// objects passing PyComplex_Check are accepted (plain int/float raise
// InvalidArgument). Components are read as double and passed to the
// complex<float> constructor.
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
                                             const std::string& op_type,
                                             ssize_t arg_pos) {
  if (!PyComplex_Check(obj)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "complex, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }

  const double re = PyComplex_RealAsDouble(obj);
  const double im = PyComplex_ImagAsDouble(obj);
  return phi::dtype::complex<float>(re, im);
}

261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
// Converts a Python complex object to phi::dtype::complex<double>. Only
// objects passing PyComplex_Check are accepted (plain int/float raise
// InvalidArgument).
phi::dtype::complex<double> CastPyArg2Complex128(PyObject* obj,
                                                 const std::string& op_type,
                                                 ssize_t arg_pos) {
  if (!PyComplex_Check(obj)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "complex, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }

  const double re = PyComplex_RealAsDouble(obj);
  const double im = PyComplex_ImagAsDouble(obj);
  return phi::dtype::complex<double>(re, im);
}

280 281 282 283 284 285
// Parses `obj` with CastPyArg2Double and stores the double in attrs[key],
// replacing any previous value. A parse error is raised before `attrs` is
// modified.
void CastPyArg2AttrDouble(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos) {
  const double parsed = CastPyArg2Double(obj, op_type, arg_pos);
  attrs[key] = parsed;
}

288 289
std::string CastPyArg2String(PyObject* obj,
                             const std::string& op_type,
290
                             ssize_t arg_pos) {
291 292 293 294
  if (PyObject_CheckString(obj)) {
    Py_ssize_t size;
    const char* data;
    data = PyUnicode_AsUTF8AndSize(obj, &size);
295
    return std::string(data, (size_t)size);  // NOLINT
296 297 298 299
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "str, but got %s",
300 301
        op_type,
        arg_pos + 1,
302 303
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
304 305

  return "";
306 307
}

308 309
void CastPyArg2AttrString(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
310 311
                          const std::string& key,
                          const std::string& op_type,
312 313 314 315
                          ssize_t arg_pos) {
  attrs[key] = CastPyArg2String(obj, op_type, arg_pos);
}

316 317
std::vector<bool> CastPyArg2Booleans(PyObject* obj,
                                     const std::string& op_type,
318 319
                                     ssize_t arg_pos) {
  std::vector<bool> value;
320 321 322 323 324 325 326 327 328 329 330
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckBool(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of bool, but got %s at pos %d",
331 332
            op_type,
            arg_pos + 1,
333 334 335 336 337 338 339 340 341 342 343 344 345 346 347
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckBool(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of bool, but got %s at pos %d",
348 349
            op_type,
            arg_pos + 1,
350 351 352 353 354 355 356 357
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
358 359
        op_type,
        arg_pos + 1,
360 361
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
362 363

  return value;
364 365
}

366 367
void CastPyArg2AttrBooleans(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
368 369
                            const std::string& key,
                            const std::string& op_type,
370 371 372 373
                            ssize_t arg_pos) {
  attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos);
}

374 375
std::vector<int> CastPyArg2Ints(PyObject* obj,
                                const std::string& op_type,
376 377
                                ssize_t arg_pos) {
  std::vector<int> value;
378 379
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
Z
zyfncg 已提交
380
    value.reserve(len);
381 382 383 384 385 386 387 388 389
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
390 391
            op_type,
            arg_pos + 1,
392 393 394 395 396 397
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
Z
zyfncg 已提交
398
    value.reserve(len);
399 400 401 402 403 404 405 406 407
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
408 409
            op_type,
            arg_pos + 1,
410 411 412 413
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
W
wanghuancoder 已提交
414
  } else if (PySequence_Check(obj) && !PyObject_TypeCheck(obj, p_tensor_type)) {
415
    Py_ssize_t len = PySequence_Size(obj);
Z
zyfncg 已提交
416
    value.reserve(len);
417 418 419 420 421 422 423 424 425
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PySequence_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
        value.emplace_back(PyLong_AsLong(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
426 427
            op_type,
            arg_pos + 1,
428 429 430 431 432 433 434 435
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
436 437
        op_type,
        arg_pos + 1,
438 439
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
440 441

  return value;
442 443
}

444 445
void CastPyArg2AttrInts(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
446 447
                        const std::string& key,
                        const std::string& op_type,
448 449 450 451
                        ssize_t arg_pos) {
  attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos);
}

452 453
std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
                                     const std::string& op_type,
454 455
                                     ssize_t arg_pos) {
  std::vector<int64_t> value;
456 457 458 459 460 461
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
462
        value.emplace_back((int64_t)PyLong_AsLongLong(item));
463 464 465 466
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
467 468
            op_type,
            arg_pos + 1,
469 470 471 472 473 474 475 476 477 478
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
479
        value.emplace_back((int64_t)PyLong_AsLongLong(item));
480 481 482 483
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
484 485
            op_type,
            arg_pos + 1,
486 487 488 489
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
W
wanghuancoder 已提交
490
  } else if (PySequence_Check(obj) && !PyObject_TypeCheck(obj, p_tensor_type)) {
491 492 493 494 495
    Py_ssize_t len = PySequence_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PySequence_GetItem(obj, i);
      if (PyObject_CheckLongOrToLong(&item)) {
496
        value.emplace_back((int64_t)PyLong_AsLongLong(item));
497 498 499 500
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of int, but got %s at pos %d",
501 502
            op_type,
            arg_pos + 1,
503 504 505 506
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
507 508 509
  } else if (obj == Py_None) {
    return {};
  } else if (PyObject_CheckLongOrToLong(&obj)) {
510
    return {(int64_t)PyLong_AsLongLong(obj)};
511
  } else {
512 513 514
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
515 516
        op_type,
        arg_pos + 1,
517 518
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
519 520

  return value;
521 522
}

523 524
void CastPyArg2AttrLongs(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
525 526
                         const std::string& key,
                         const std::string& op_type,
527 528 529 530
                         ssize_t arg_pos) {
  attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos);
}

531 532
std::vector<float> CastPyArg2Floats(PyObject* obj,
                                    const std::string& op_type,
533 534
                                    ssize_t arg_pos) {
  std::vector<float> value;
535 536 537 538 539 540 541 542 543 544 545
  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyList_GetItem(obj, i);
      if (PyObject_CheckFloatOrToFloat(&item)) {
        value.emplace_back(PyFloat_AsDouble(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but got %s at pos %d",
546 547
            op_type,
            arg_pos + 1,
548 549 550 551 552 553 554 555 556 557 558 559 560 561 562
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PyTuple_GetItem(obj, i);
      if (PyObject_CheckFloatOrToFloat(&item)) {
        value.emplace_back(PyFloat_AsDouble(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but got %s at pos %d",
563 564
            op_type,
            arg_pos + 1,
565 566 567 568
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
W
wanghuancoder 已提交
569
  } else if (PySequence_Check(obj) && !PyObject_TypeCheck(obj, p_tensor_type)) {
570 571 572 573 574 575 576 577 578 579
    Py_ssize_t len = PySequence_Size(obj);
    PyObject* item = nullptr;
    for (Py_ssize_t i = 0; i < len; i++) {
      item = PySequence_GetItem(obj, i);
      if (PyObject_CheckFloatOrToFloat(&item)) {
        value.emplace_back(PyFloat_AsDouble(item));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but got %s at pos %d",
580 581
            op_type,
            arg_pos + 1,
582 583 584 585 586 587 588 589
            ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
            i));
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
590 591
        op_type,
        arg_pos + 1,
592 593
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
594 595

  return value;
596 597
}

598 599
void CastPyArg2AttrFloats(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
600 601
                          const std::string& key,
                          const std::string& op_type,
602 603 604 605 606 607 608 609
                          ssize_t arg_pos) {
  attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos);
}

// Converts a Python list, tuple, or other non-Tensor sequence of float-like
// objects (see PyObject_CheckFloatOrToFloat) into std::vector<double>, each
// value read with PyFloat_AsDouble. Raises InvalidArgument naming `op_type`,
// the 1-based argument position and, for element errors, the element index.
std::vector<double> CastPyArg2Float64s(PyObject* obj,
                                       const std::string& op_type,
                                       ssize_t arg_pos) {
  std::vector<double> value;

  // Appends one element (`item` is borrowed). PyObject_CheckFloatOrToFloat
  // may substitute a NEW reference for numpy scalars; that one is owned here
  // and released once its value has been read.
  auto append_item = [&](PyObject* item, Py_ssize_t i) {
    PyObject* converted = item;
    if (!PyObject_CheckFloatOrToFloat(&converted)) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "list of float, but got %s at pos %d",
          op_type,
          arg_pos + 1,
          ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
          i));
    }
    value.emplace_back(PyFloat_AsDouble(converted));
    if (converted != item) {
      Py_DECREF(converted);
    }
  };

  auto throw_not_sequence = [&]() {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  };

  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    value.reserve(len);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyList_GetItem(obj, i), i);
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    value.reserve(len);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyTuple_GetItem(obj, i), i);
    }
  } else if (PySequence_Check(obj) && !PyObject_TypeCheck(obj, p_tensor_type)) {
    Py_ssize_t len = PySequence_Size(obj);
    if (len < 0) {
      // No usable length and a Python error is pending; without this the
      // loop would be skipped and that stale error returned to Python.
      PyErr_Clear();
      throw_not_sequence();
    }
    value.reserve(len);
    for (Py_ssize_t i = 0; i < len; i++) {
      // PySequence_GetItem returns a NEW reference; the guard releases it on
      // every path, including when append_item throws.
      std::unique_ptr<PyObject, void (*)(PyObject*)> item(
          PySequence_GetItem(obj, i), &Py_DecRef);
      if (!item) {
        PyErr_Clear();
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument (position %d) must be "
            "list of float, but reading the item at pos %d failed",
            op_type,
            arg_pos + 1,
            i));
      }
      append_item(item.get(), i);
    }
  } else {
    throw_not_sequence();
  }

  return value;
}

673 674
void CastPyArg2AttrFloat64s(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
675 676
                            const std::string& key,
                            const std::string& op_type,
677 678 679 680
                            ssize_t arg_pos) {
  attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos);
}

681 682 683 684 685 686 687 688
// Parses `obj` with CastPyArg2Scalars and stores the vector of Scalars in
// attrs[key], replacing any previous value. A parse error is raised before
// `attrs` is modified.
void CastPyArg2AttrScalars(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
                           const std::string& key,
                           const std::string& op_type,
                           ssize_t arg_pos) {
  auto parsed = CastPyArg2Scalars(obj, op_type, arg_pos);
  attrs[key] = std::move(parsed);
}

689 690 691 692
// Converts a Python list or tuple of str into std::vector<std::string> (UTF-8
// bytes). Raises InvalidArgument naming `op_type`, the 1-based argument
// position and the element index when an element is not a str or cannot be
// encoded as UTF-8, or when `obj` is neither a list nor a tuple.
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
                                           const std::string& op_type,
                                           ssize_t arg_pos) {
  std::vector<std::string> value;

  // Validates and appends one (borrowed) element.
  auto append_item = [&](PyObject* item, Py_ssize_t i) {
    if (!PyObject_CheckString(item)) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "list of str, but got %s at pos %d",
          op_type,
          arg_pos + 1,
          ((PyTypeObject*)item->ob_type)->tp_name,  // NOLINT
          i));
    }
    Py_ssize_t size = 0;
    const char* data = PyUnicode_AsUTF8AndSize(item, &size);
    if (data == nullptr) {
      // Encoding failed (e.g. lone surrogates): a Python exception is
      // pending and `size` is unset. std::string(nullptr, n) is undefined,
      // so clear the Python error and raise a Paddle error instead.
      PyErr_Clear();
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "list of UTF-8 encodable str, but the str at pos %d is not",
          op_type,
          arg_pos + 1,
          i));
    }
    value.emplace_back(data, (size_t)size);  // NOLINT
  };

  if (PyList_Check(obj)) {
    Py_ssize_t len = PyList_Size(obj);
    value.reserve(len);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyList_GetItem(obj, i), i);
    }
  } else if (PyTuple_Check(obj)) {
    Py_ssize_t len = PyTuple_Size(obj);
    value.reserve(len);
    for (Py_ssize_t i = 0; i < len; i++) {
      append_item(PyTuple_GetItem(obj, i), i);
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "list or tuple, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }

  return value;
}

void CastPyArg2AttrStrings(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
747 748
                           const std::string& key,
                           const std::string& op_type,
749 750
                           ssize_t arg_pos) {
  attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos);
751 752
}

753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810
// Converts a Python list into std::vector<Scalar>. The element kind is chosen
// by inspecting only the FIRST element, and the whole list is read that way:
//   * float-like (PyObject_CheckFloatOrToFloat: Python float/int/bool, numpy
//     scalar, 1-element Tensor) -> PyFloat_AsDouble for every element;
//   * otherwise integer-like (PyObject_CheckLongOrToLong) -> PyLong_AsLong;
//   * otherwise complex-like (PyObject_CheckComplexOrToComplex)
//     -> PyComplex_AsCComplex.
// NOTE(review): Python ints already pass the float check, so integer lists
// become double Scalars (magnitudes above 2^53 lose precision). Confirm this
// is intended before relying on the int64 branch.
// An empty list yields an empty vector. None, non-list inputs, and lists
// whose first element matches none of the kinds raise InvalidArgument.
std::vector<paddle::experimental::Scalar> CastPyArg2Scalars(
    PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
  if (obj == Py_None) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "a list of int, float, or bool, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }

  PyTypeObject* type = obj->ob_type;
  auto type_name = std::string(type->tp_name);
  VLOG(4) << "type_name: " << type_name;
  if (!PyList_Check(obj)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "a list of int, float, complex, or bool, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }

  std::vector<paddle::experimental::Scalar> value;
  Py_ssize_t len = PyList_Size(obj);
  if (len == 0) {
    // PyList_GetItem(obj, 0) would return nullptr (setting IndexError) and
    // the checks below would dereference it.
    return value;
  }
  value.reserve(len);

  PyObject* first = PyList_GetItem(obj, 0);
  // `probe` is only used to classify the list. The checkers may swap it for
  // a NEW reference (numpy conversion), which is released straight away.
  PyObject* probe = first;
  if (PyObject_CheckFloatOrToFloat(&probe)) {
    if (probe != first) {
      Py_DECREF(probe);
    }
    for (Py_ssize_t i = 0; i < len; i++) {
      PyObject* item = PyList_GetItem(obj, i);
      value.emplace_back(
          paddle::experimental::Scalar{PyFloat_AsDouble(item)});
    }
    return value;
  }

  // A failed check never replaces `probe`, so it still equals `first` here.
  if (PyObject_CheckLongOrToLong(&probe)) {
    if (probe != first) {
      Py_DECREF(probe);
    }
    for (Py_ssize_t i = 0; i < len; i++) {
      PyObject* item = PyList_GetItem(obj, i);
      value.emplace_back(paddle::experimental::Scalar{
          static_cast<int64_t>(PyLong_AsLong(item))});
    }
    return value;
  }

  if (PyObject_CheckComplexOrToComplex(&probe)) {
    for (Py_ssize_t i = 0; i < len; i++) {
      PyObject* item = PyList_GetItem(obj, i);
      Py_complex v = PyComplex_AsCComplex(item);
      value.emplace_back(
          paddle::experimental::Scalar{std::complex<double>(v.real, v.imag)});
    }
    return value;
  }

  // Previously this fell through to a fake {Scalar(1.0)}, silently turning
  // e.g. a list of str into a bogus value. Report it instead.
  PADDLE_THROW(platform::errors::InvalidArgument(
      "%s(): argument (position %d) must be "
      "a list of int, float, complex, or bool, but got a list of %s",
      op_type,
      arg_pos + 1,
      ((PyTypeObject*)first->ob_type)->tp_name));  // NOLINT
  return value;
}

811 812
void CastPyArg2AttrBlock(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
813 814
                         const std::string& key,
                         const std::string& op_type,
815 816 817 818
                         ssize_t arg_pos) {
  ::pybind11::detail::instance* inst =
      (::pybind11::detail::instance*)obj;  // NOLINT

819
  if (!PyObject_TypeCheck((PyObject*)inst, g_blockdesc_pytype)) {  // NOLINT
820 821 822
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "BlockDesc, but got %s",
823 824
        op_type,
        arg_pos + 1,
825 826 827 828 829 830 831 832
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }
  void** vh = inst->simple_layout ? inst->simple_value_holder
                                  : &inst->nonsimple.values_and_holders[0];
  attrs[key] = reinterpret_cast<paddle::framework::BlockDesc*&>(vh[0]);
}

void ConstructAttrMapFromPyArgs(
833 834 835 836 837 838 839
    const std::string& op_type,
    PyObject* args,
    ssize_t attr_start,
    ssize_t attr_end,
    paddle::framework::AttributeMap& attrs) {  // NOLINT
  PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2,
                    0,
840 841 842
                    platform::errors::InvalidArgument(
                        "The number of arguments for attributes should be even "
                        "but attr_start = %d, attr_end = %d.",
843 844
                        attr_start,
                        attr_end));
845 846 847 848 849 850 851 852 853 854 855 856 857 858

  auto attr_type_map = &(OpAttrTypeMap::Instance().Map()[op_type]);

  PyObject* obj = nullptr;
  for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) {
    Py_ssize_t key_len;
    const char* key_ptr;
    obj = PyTuple_GET_ITEM(args, arg_pos);
    if (PyObject_CheckString(obj)) {
      key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be str, but got "
          "%s",
859 860 861
          op_type,
          arg_pos,
          ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878
    }

    std::string key(key_ptr, (size_t)key_len);  // NOLINT
    auto iter = attr_type_map->find(key);
    if (iter == attr_type_map->end()) {
      continue;
    }

    obj = PyTuple_GET_ITEM(args, arg_pos + 1);

    switch (iter->second) {
      case paddle::framework::proto::AttrType::INT:
        CastPyArg2AttrInt(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::FLOAT:
        CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos);
        break;
879 880 881
      case paddle::framework::proto::AttrType::FLOAT64:
        CastPyArg2AttrDouble(obj, attrs, key, op_type, arg_pos);
        break;
882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911
      case paddle::framework::proto::AttrType::STRING:
        CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::INTS:
        CastPyArg2AttrInts(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::FLOATS:
        CastPyArg2AttrFloats(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::STRINGS:
        CastPyArg2AttrStrings(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::BOOLEAN:
        CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::BOOLEANS:
        CastPyArg2AttrBooleans(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::LONG:
        CastPyArg2AttrLong(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::LONGS:
        CastPyArg2AttrLongs(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::FLOAT64S:
        CastPyArg2AttrFloat64s(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::BLOCK:
        CastPyArg2AttrBlock(obj, attrs, key, op_type, arg_pos);
        break;
912 913 914 915 916 917
      case paddle::framework::proto::AttrType::SCALAR:
        CastPyArg2AttrScalar(obj, attrs, key, op_type, arg_pos);
        break;
      case paddle::framework::proto::AttrType::SCALARS:
        CastPyArg2AttrScalars(obj, attrs, key, op_type, arg_pos);
        break;
918 919 920 921 922 923 924
      default:
        break;
    }
  }
}

unsigned long GetUnsignedLongFromArgs(  // NOLINT
925 926 927 928 929
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable) {
930 931 932 933 934 935
  PyObject* item = PyTuple_GET_ITEM(args, arg_idx);

  if (item == nullptr) {
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be long, but got None",
936 937 938
          op_type,
          arg_name,
          arg_idx));
939 940 941 942 943 944 945 946 947 948
    }
    return 0;
  }

  if (PyObject_CheckLongOrToLong(&item)) {
    return PyLong_AsUnsignedLong(item);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be "
        "long, but got %s",
949 950 951
        op_type,
        arg_name,
        arg_idx,
952 953 954 955 956 957
        ((PyTypeObject*)item->ob_type)->tp_name));  // NOLINT
  }
}

void InitOpsAttrTypeMap() {
  auto op_info_map = paddle::framework::OpInfoMap::Instance().map();
958 959
  for (auto& item : op_info_map) {
    auto op_proto = item.second.proto_;
960 961 962 963 964
    if (op_proto == nullptr) {
      continue;
    }
    auto attrs_proto = op_proto->attrs();
    for (auto& attr : attrs_proto) {
965
      OpAttrTypeMap::Instance().Map()[item.first][attr.name()] = attr.type();
966 967
    }
  }
968 969 970 971 972 973 974 975 976
  const auto& extra_attr_maps =
      operators::ExtraInfoUtils::Instance().GetAllExtraAttrsMap();
  for (const auto& extra_attrs : extra_attr_maps) {
    for (auto& attr : extra_attrs.second) {
      OpAttrTypeMap::Instance().Map()[extra_attrs.first][attr.first] =
          static_cast<paddle::framework::proto::AttrType>(attr.second.index() -
                                                          1);
    }
  }
977 978
}

979 980 981
ssize_t GetIdxFromCoreOpsInfoMap(
    const std::unordered_map<std::string, std::vector<std::string>>&
        core_ops_info_map,
982 983
    const std::string& op_type,
    const std::string& name) {
984 985 986 987 988 989 990 991 992 993 994 995
  // `core_ops_info_map` can be `core_ops_args_info` or `core_ops_returns_info`.
  // `core_ops_args_info`: get index from core_ops_args_info[op_type] according
  // to input name.
  // `core_ops_returns_info`: get index from core_ops_returns_info[op_type]
  // according to return name.
  if (!core_ops_info_map.count(op_type)) {
    PADDLE_THROW(platform::errors::Fatal(
        "Op %s is not found in core_ops_*_info map.", op_type));
  } else {
    auto args_list = core_ops_info_map.at(op_type);
    auto it = std::find(args_list.begin(), args_list.end(), name);
    if (it == args_list.end()) {
996 997
      PADDLE_THROW(platform::errors::Fatal(
          "%s is not found in op %s's args.", name, op_type));
998 999 1000 1001 1002 1003 1004
    } else {
      return std::distance(args_list.begin(), it);
    }
  }
  return -1;
}

1005 1006
}  // namespace pybind
}  // namespace paddle