// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// MSVC does not ship the POSIX `ssize_t`; alias it to the Windows `SSIZE_T`
// from <BaseTsd.h> so the signed size parameters below compile everywhere.
#if defined(_MSC_VER)
#include <BaseTsd.h>
using ssize_t = SSIZE_T;
#endif

// pybind11 headers pull in Python.h, which must be included before any
// standard header, so they stay first.
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/imperative.h"

namespace py = pybind11;
namespace paddle {
namespace pybind {

bool PyObject_CheckBool(PyObject** obj);

bool PyObject_CheckLongOrToLong(PyObject** obj);

bool PyObject_CheckFloatOrToFloat(PyObject** obj);

bool PyObject_CheckString(PyObject* obj);

51 52
bool CastPyArg2Boolean(PyObject* obj,
                       const std::string& op_type,
53 54
                       ssize_t arg_pos);
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
55 56
int64_t CastPyArg2Long(PyObject* obj,
                       const std::string& op_type,
57
                       ssize_t arg_pos);
58 59 60
float16 CastPyArg2Float16(PyObject* obj,
                          const std::string& op_type,
                          ssize_t arg_pos);
61 62
float CastPyArg2Float(PyObject* obj,
                      const std::string& op_type,
63
                      ssize_t arg_pos);
64 65
double CastPyArg2Double(PyObject* obj,
                        const std::string& op_type,
66
                        ssize_t arg_pos);
67 68 69
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
                                             const std::string& op_type,
                                             ssize_t arg_pos);
70 71
std::string CastPyArg2String(PyObject* obj,
                             const std::string& op_type,
72
                             ssize_t arg_pos);
73 74
std::vector<bool> CastPyArg2Booleans(PyObject* obj,
                                     const std::string& op_type,
75
                                     ssize_t arg_pos);
76 77
std::vector<int> CastPyArg2Ints(PyObject* obj,
                                const std::string& op_type,
78
                                ssize_t arg_pos);
79 80
std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
                                     const std::string& op_type,
81
                                     ssize_t arg_pos);
82 83
std::vector<float> CastPyArg2Floats(PyObject* obj,
                                    const std::string& op_type,
84 85 86 87 88 89 90 91
                                    ssize_t arg_pos);
std::vector<double> CastPyArg2Float64s(PyObject* obj,
                                       const std::string& op_type,
                                       ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
                                           const std::string& op_type,
                                           ssize_t arg_pos);

92 93
void CastPyArg2AttrBoolean(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
94 95
                           const std::string& key,
                           const std::string& op_type,
96 97 98 99
                           ssize_t arg_pos);

void CastPyArg2AttrInt(PyObject* obj,
                       paddle::framework::AttributeMap& attrs,  // NOLINT
100 101
                       const std::string& key,
                       const std::string& op_type,
102 103 104 105
                       ssize_t arg_pos);

void CastPyArg2AttrLong(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
106 107
                        const std::string& key,
                        const std::string& op_type,
108 109 110 111
                        ssize_t arg_pos);

void CastPyArg2AttrFloat(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
112 113
                         const std::string& key,
                         const std::string& op_type,
114 115
                         ssize_t arg_pos);

116 117 118 119 120 121
void CastPyArg2AttrDouble(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos);

122 123
void CastPyArg2AttrString(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
124 125
                          const std::string& key,
                          const std::string& op_type,
126 127 128 129
                          ssize_t arg_pos);

void CastPyArg2AttrBooleans(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
130 131
                            const std::string& key,
                            const std::string& op_type,
132 133 134 135
                            ssize_t arg_pos);

void CastPyArg2AttrInts(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
136 137
                        const std::string& key,
                        const std::string& op_type,
138 139 140 141
                        ssize_t arg_pos);

void CastPyArg2AttrLongs(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
142 143
                         const std::string& key,
                         const std::string& op_type,
144 145 146 147
                         ssize_t arg_pos);

void CastPyArg2AttrFloats(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
148 149
                          const std::string& key,
                          const std::string& op_type,
150 151 152 153
                          ssize_t arg_pos);

void CastPyArg2AttrFloat64s(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
154 155
                            const std::string& key,
                            const std::string& op_type,
156 157 158 159
                            ssize_t arg_pos);

void CastPyArg2AttrStrings(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
160 161
                           const std::string& key,
                           const std::string& op_type,
162 163 164 165
                           ssize_t arg_pos);

void CastPyArg2AttrBlock(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
166 167
                         const std::string& key,
                         const std::string& op_type,
168 169 170
                         ssize_t arg_pos);

void ConstructAttrMapFromPyArgs(
171 172 173
    const std::string& op_type,
    PyObject* args,
    ssize_t attr_start,
174 175 176 177
    ssize_t attr_end,
    paddle::framework::AttributeMap& attrs);  // NOLINT

std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
178 179 180 181 182
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);
183 184

std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
185 186 187 188 189
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);
190 191

unsigned long GetUnsignedLongFromArgs(  // NOLINT
192 193 194 195 196
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);
197 198 199

void InitOpsAttrTypeMap();

200 201 202
ssize_t GetIdxFromCoreOpsInfoMap(
    const std::unordered_map<std::string, std::vector<std::string>>&
        core_ops_info_map,
203 204
    const std::string& op_type,
    const std::string& name);
205

206 207
}  // namespace pybind
}  // namespace paddle