// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#if defined(_MSC_VER)
// MSVC does not define the POSIX `ssize_t` used throughout this header, so
// alias it to the Windows SDK's signed, pointer-sized SSIZE_T from
// <BaseTsd.h>. Other toolchains already provide `ssize_t` themselves.
#include <BaseTsd.h>
using ssize_t = SSIZE_T;
#endif

#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/imperative.h"

namespace py = pybind11;
namespace paddle {
namespace pybind {

bool PyObject_CheckBool(PyObject** obj);

bool PyObject_CheckLongOrToLong(PyObject** obj);

bool PyObject_CheckFloatOrToFloat(PyObject** obj);

bool PyObject_CheckString(PyObject* obj);

51 52
bool CastPyArg2Boolean(PyObject* obj,
                       const std::string& op_type,
53 54
                       ssize_t arg_pos);
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
55 56
int64_t CastPyArg2Long(PyObject* obj,
                       const std::string& op_type,
57
                       ssize_t arg_pos);
58 59
float CastPyArg2Float(PyObject* obj,
                      const std::string& op_type,
60
                      ssize_t arg_pos);
61 62
double CastPyArg2Double(PyObject* obj,
                        const std::string& op_type,
63
                        ssize_t arg_pos);
64 65 66
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
                                             const std::string& op_type,
                                             ssize_t arg_pos);
67 68
std::string CastPyArg2String(PyObject* obj,
                             const std::string& op_type,
69
                             ssize_t arg_pos);
70 71
std::vector<bool> CastPyArg2Booleans(PyObject* obj,
                                     const std::string& op_type,
72
                                     ssize_t arg_pos);
73 74
std::vector<int> CastPyArg2Ints(PyObject* obj,
                                const std::string& op_type,
75
                                ssize_t arg_pos);
76 77
std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
                                     const std::string& op_type,
78
                                     ssize_t arg_pos);
79 80
std::vector<float> CastPyArg2Floats(PyObject* obj,
                                    const std::string& op_type,
81 82 83 84 85 86 87 88
                                    ssize_t arg_pos);
std::vector<double> CastPyArg2Float64s(PyObject* obj,
                                       const std::string& op_type,
                                       ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
                                           const std::string& op_type,
                                           ssize_t arg_pos);

89 90
void CastPyArg2AttrBoolean(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
91 92
                           const std::string& key,
                           const std::string& op_type,
93 94 95 96
                           ssize_t arg_pos);

void CastPyArg2AttrInt(PyObject* obj,
                       paddle::framework::AttributeMap& attrs,  // NOLINT
97 98
                       const std::string& key,
                       const std::string& op_type,
99 100 101 102
                       ssize_t arg_pos);

void CastPyArg2AttrLong(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
103 104
                        const std::string& key,
                        const std::string& op_type,
105 106 107 108
                        ssize_t arg_pos);

void CastPyArg2AttrFloat(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
109 110
                         const std::string& key,
                         const std::string& op_type,
111 112
                         ssize_t arg_pos);

113 114 115 116 117 118
void CastPyArg2AttrDouble(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos);

119 120
void CastPyArg2AttrString(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
121 122
                          const std::string& key,
                          const std::string& op_type,
123 124 125 126
                          ssize_t arg_pos);

void CastPyArg2AttrBooleans(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
127 128
                            const std::string& key,
                            const std::string& op_type,
129 130 131 132
                            ssize_t arg_pos);

void CastPyArg2AttrInts(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
133 134
                        const std::string& key,
                        const std::string& op_type,
135 136 137 138
                        ssize_t arg_pos);

void CastPyArg2AttrLongs(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
139 140
                         const std::string& key,
                         const std::string& op_type,
141 142 143 144
                         ssize_t arg_pos);

void CastPyArg2AttrFloats(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
145 146
                          const std::string& key,
                          const std::string& op_type,
147 148 149 150
                          ssize_t arg_pos);

void CastPyArg2AttrFloat64s(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
151 152
                            const std::string& key,
                            const std::string& op_type,
153 154 155 156
                            ssize_t arg_pos);

void CastPyArg2AttrStrings(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
157 158
                           const std::string& key,
                           const std::string& op_type,
159 160 161 162
                           ssize_t arg_pos);

void CastPyArg2AttrBlock(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
163 164
                         const std::string& key,
                         const std::string& op_type,
165 166 167
                         ssize_t arg_pos);

void ConstructAttrMapFromPyArgs(
168 169 170
    const std::string& op_type,
    PyObject* args,
    ssize_t attr_start,
171 172 173 174
    ssize_t attr_end,
    paddle::framework::AttributeMap& attrs);  // NOLINT

std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
175 176 177 178 179
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);
180 181

std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
182 183 184 185 186
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);
187 188

unsigned long GetUnsignedLongFromArgs(  // NOLINT
189 190 191 192 193
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);
194 195 196

void InitOpsAttrTypeMap();

197 198 199
ssize_t GetIdxFromCoreOpsInfoMap(
    const std::unordered_map<std::string, std::vector<std::string>>&
        core_ops_info_map,
200 201
    const std::string& op_type,
    const std::string& name);
202

203 204
}  // namespace pybind
}  // namespace paddle