// op_function_common.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#if defined(_MSC_VER)
// MSVC's headers do not declare the POSIX `ssize_t` type, which the
// declarations in this header use for argument positions/indices. Alias it to
// the Windows SDK's signed size type from <BaseTsd.h>.
#include <BaseTsd.h>
using ssize_t = SSIZE_T;
#endif

// pybind11 (and therefore Python.h) is kept ahead of the standard headers:
// Python requires Python.h to be included before any standard header.
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/imperative.h"

namespace py = pybind11;
namespace paddle {
namespace pybind {

// ---------------------------------------------------------------------------
// Type predicates on raw Python objects.
//
// NOTE(review): CheckBool and the *OrTo* variants take `PyObject**`, which
// suggests they may replace `*obj` with a converted object when they succeed.
// Confirm the replacement and reference-count semantics in the .cc before
// relying on the original `*obj` afterwards.
// ---------------------------------------------------------------------------
bool PyObject_CheckBool(PyObject** obj);

bool PyObject_CheckLongOrToLong(PyObject** obj);

bool PyObject_CheckFloatOrToFloat(PyObject** obj);

bool PyObject_CheckString(PyObject* obj);

// ---------------------------------------------------------------------------
// Converters from a single Python argument `obj` to a C++ value.
//
// `op_type` names the operator being called and `arg_pos` is the position of
// `obj` among its arguments. Presumably both are used only to build the error
// message when `obj` has an unexpected type; confirm in the implementation.
// ---------------------------------------------------------------------------

// Scalar converters.
bool CastPyArg2Boolean(PyObject* obj,
                       const std::string& op_type,
                       ssize_t arg_pos);
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
int64_t CastPyArg2Long(PyObject* obj,
                       const std::string& op_type,
                       ssize_t arg_pos);
float CastPyArg2Float(PyObject* obj,
                      const std::string& op_type,
                      ssize_t arg_pos);
double CastPyArg2Double(PyObject* obj,
                        const std::string& op_type,
                        ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj,
                             const std::string& op_type,
                             ssize_t arg_pos);

// Sequence converters: `obj` is presumably a Python list/tuple whose
// elements are converted one by one (verify accepted container types).
std::vector<bool> CastPyArg2Booleans(PyObject* obj,
                                     const std::string& op_type,
                                     ssize_t arg_pos);
std::vector<int> CastPyArg2Ints(PyObject* obj,
                                const std::string& op_type,
                                ssize_t arg_pos);
std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
                                     const std::string& op_type,
                                     ssize_t arg_pos);
std::vector<float> CastPyArg2Floats(PyObject* obj,
                                    const std::string& op_type,
                                    ssize_t arg_pos);
std::vector<double> CastPyArg2Float64s(PyObject* obj,
                                       const std::string& op_type,
                                       ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
                                           const std::string& op_type,
                                           ssize_t arg_pos);

// ---------------------------------------------------------------------------
// Attribute converters.
//
// Each converts `obj` like the matching CastPyArg2* function and, presumably,
// stores the result in the out-parameter `attrs` under `key` (confirm whether
// an existing entry is overwritten). `op_type`/`arg_pos` are the same as
// above.
// ---------------------------------------------------------------------------
void CastPyArg2AttrBoolean(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
                           const std::string& key,
                           const std::string& op_type,
                           ssize_t arg_pos);

void CastPyArg2AttrInt(PyObject* obj,
                       paddle::framework::AttributeMap& attrs,  // NOLINT
                       const std::string& key,
                       const std::string& op_type,
                       ssize_t arg_pos);

void CastPyArg2AttrLong(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
                        const std::string& key,
                        const std::string& op_type,
                        ssize_t arg_pos);

void CastPyArg2AttrFloat(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
                         const std::string& key,
                         const std::string& op_type,
                         ssize_t arg_pos);

void CastPyArg2AttrDouble(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos);

void CastPyArg2AttrString(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos);

void CastPyArg2AttrBooleans(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
                            const std::string& key,
                            const std::string& op_type,
                            ssize_t arg_pos);

void CastPyArg2AttrInts(PyObject* obj,
                        paddle::framework::AttributeMap& attrs,  // NOLINT
                        const std::string& key,
                        const std::string& op_type,
                        ssize_t arg_pos);

void CastPyArg2AttrLongs(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
                         const std::string& key,
                         const std::string& op_type,
                         ssize_t arg_pos);

void CastPyArg2AttrFloats(PyObject* obj,
                          paddle::framework::AttributeMap& attrs,  // NOLINT
                          const std::string& key,
                          const std::string& op_type,
                          ssize_t arg_pos);

void CastPyArg2AttrFloat64s(PyObject* obj,
                            paddle::framework::AttributeMap& attrs,  // NOLINT
                            const std::string& key,
                            const std::string& op_type,
                            ssize_t arg_pos);

void CastPyArg2AttrStrings(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
                           const std::string& key,
                           const std::string& op_type,
                           ssize_t arg_pos);

// NOTE(review): the expected Python type for a block attribute is not visible
// here -- check the implementation.
void CastPyArg2AttrBlock(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
                         const std::string& key,
                         const std::string& op_type,
                         ssize_t arg_pos);

// Fills the out-parameter `attrs` from the items of the Python argument tuple
// `args` that lie between `attr_start` and `attr_end`.
// NOTE(review): whether the bounds are inclusive/exclusive, and whether items
// are laid out as alternating name/value pairs, must be confirmed in the .cc.
void ConstructAttrMapFromPyArgs(
    const std::string& op_type,
    PyObject* args,
    ssize_t attr_start,
    ssize_t attr_end,
    paddle::framework::AttributeMap& attrs);  // NOLINT

// Returns the VarBase passed as input `arg_name` at position `arg_idx` of
// `args`. `dispensable` presumably allows the argument to be None/absent,
// yielding an empty pointer instead of an error -- confirm in the .cc.
std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);

// List counterpart of GetVarBaseFromArgs: the argument at `arg_idx` is
// presumably a Python list/tuple of VarBase objects.
std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);

// Reads the argument at `arg_idx` of `args` as an unsigned long; the value
// returned when `dispensable` is true and the argument is absent is not
// visible here.
unsigned long GetUnsignedLongFromArgs(  // NOLINT
    const std::string& op_type,
    const std::string& arg_name,
    PyObject* args,
    ssize_t arg_idx,
    bool dispensable = false);

// Initializes the operator attribute-type table. Presumably must run before
// the attribute converters above are used -- confirm against callers.
void InitOpsAttrTypeMap();

// Looks up `name` in the list `core_ops_info_map[op_type]` and returns its
// index. NOTE(review): the result when `op_type` or `name` is missing is not
// visible here.
ssize_t GetIdxFromCoreOpsInfoMap(
    const std::unordered_map<std::string, std::vector<std::string>>&
        core_ops_info_map,
    const std::string& op_type,
    const std::string& name);

}  // namespace pybind
}  // namespace paddle