op_function.h 9.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
21

22 23 24
#include <memory>
#include <string>
#include <vector>
25

26 27
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
28
#include "paddle/fluid/framework/op_registry.h"
29 30 31
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
32
#include "paddle/fluid/pybind/exception.h"
33
#include "paddle/fluid/pybind/imperative.h"
34
#include "paddle/fluid/pybind/op_function_common.h"
35

36 37 38
namespace py = pybind11;
namespace paddle {
namespace pybind {
39 40

static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
41 42 43 44 45
    const std::string& op_type,
    const std::string& arg_name,
    int arg_idx,
    const py::handle& handle,
    bool dispensable = false) {
46 47
  PyObject* py_obj = handle.ptr();  // get underlying PyObject
  if (!py_obj || py_obj == Py_None) {
48 49 50 51
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be Tensor, but got "
          "%s",
52 53 54 55
          op_type,
          arg_name,
          arg_idx,
          Py_TYPE(py_obj)->tp_name));
56
    }
57 58 59 60 61 62 63 64
    return nullptr;
  }
  try {
    return py::cast<std::shared_ptr<imperative::VarBase>>(py::handle(py_obj));
  } catch (py::cast_error&) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be Tensor, but got "
        "%s",
65 66 67 68
        op_type,
        arg_name,
        arg_idx,
        Py_TYPE(py_obj)->tp_name));
69 70 71 72 73
  }
}

static inline std::vector<std::shared_ptr<imperative::VarBase>>
CastPyHandleToVarBaseList(const std::string& op_type,
74 75 76 77
                          const std::string& arg_name,
                          int arg_idx,
                          const py::handle& handle,
                          bool dispensable = false) {
78 79
  PyObject* py_obj = handle.ptr();  // get underlying PyObject
  if (!py_obj || py_obj == Py_None) {
80 81 82 83
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be Tensor, but got "
          "%s",
84 85 86 87
          op_type,
          arg_name,
          arg_idx,
          Py_TYPE(py_obj)->tp_name));
88
    }
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
    return {};
  }
  std::vector<std::shared_ptr<imperative::VarBase>> result;
  if (PyList_Check(py_obj) || PyTuple_Check(py_obj)) {
    auto size = PyTuple_Check(py_obj) ? PyTuple_GET_SIZE(py_obj)
                                      : PyList_GET_SIZE(py_obj);
    for (auto i = 0; i < size; ++i) {
      PyObject* item = PyTuple_Check(py_obj) ? PyTuple_GET_ITEM(py_obj, i)
                                             : PyList_GET_ITEM(py_obj, i);
      if (!item || item == Py_None) {
        result.emplace_back(nullptr);
        continue;
      }
      try {
        result.emplace_back(
            py::cast<std::shared_ptr<imperative::VarBase>>(py::handle(item)));
      } catch (py::cast_error&) {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "%s(): argument '%s' (position %d) must be list of "
            "Tensors, but "
            "got %s in list (item %d)",
110 111 112 113 114
            op_type,
            arg_name,
            arg_idx,
            Py_TYPE(item)->tp_name,
            i));
115 116 117 118 119 120
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be list of Tensors, but got "
        "%s",
121 122 123 124
        op_type,
        arg_name,
        arg_idx,
        Py_TYPE(py_obj)->tp_name));
125 126 127 128
  }
  return result;
}  // namespace pybind

129 130 131
static inline void ConstructAttrMapFromPyArgs(const std::string& op_type,
                                              int start_idx,
                                              framework::AttributeMap* attrs,
132 133
                                              const py::args& args) {
  PADDLE_ENFORCE_EQ(
134 135
      args.size() % 2,
      0,
136 137 138
      platform::errors::InvalidArgument(
          "The number of arguments for arributes should be even."));
  for (size_t i = 0; i < args.size(); i += 2) {
139 140 141 142 143 144 145 146 147
    std::string name;
    framework::Attribute value;
    try {
      name = args[i].cast<std::string>();
    } catch (std::exception& e) {
      PyObject* py_obj = args[i].ptr();  // get underlying PyObject
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be str, but got "
          "%s",
148 149 150
          op_type,
          start_idx + i,
          Py_TYPE(py_obj)->tp_name));
151 152 153 154 155 156 157 158 159
    }
    try {
      value = args[i + 1].cast<framework::Attribute>();
    } catch (std::exception& e) {
      PyObject* py_obj = args[i + 1].ptr();  // get underlying PyObject
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "Attribute type (one of str, bool, int, int64, float, or list of "
          "them), but got %s",
160 161 162
          op_type,
          start_idx + i + 1,
          Py_TYPE(py_obj)->tp_name));
163
    }
164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
    (*attrs)[name] = value;
  }
}

static inline std::vector<std::shared_ptr<imperative::VarBase>>
ConstructDuplicableOutput(const size_t num) {
  auto tracer = imperative::GetCurrentTracer();
  std::vector<std::shared_ptr<imperative::VarBase>> res;
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    auto var_base_name = tracer->GenerateUniqueName();
    res.emplace_back(new imperative::VarBase(var_base_name));
  }
  return res;
}
179 180 181 182 183

static inline void HandleViewBetweenInputAndOutput(
    const std::shared_ptr<imperative::VarBase>& input_var,
    const std::shared_ptr<imperative::VarBase>& view_output_var) {
  PADDLE_ENFORCE_EQ(
184 185
      input_var->Var().IsInitialized(),
      true,
186 187 188 189 190 191
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        input_var->Name()));

  if (input_var->Var().IsType<framework::LoDTensor>()) {
    const auto& input_tensor = input_var->Var().Get<framework::LoDTensor>();
    PADDLE_ENFORCE_EQ(
192 193
        input_tensor.IsInitialized(),
        true,
194 195 196 197 198
        platform::errors::InvalidArgument(
            "LoDTensor %s has not been initialized!", input_var->Name()));

    auto* view_output_tensor =
        view_output_var->MutableVar()->GetMutable<framework::LoDTensor>();
199
    view_output_tensor->ShareBufferWith(input_tensor);
200 201 202 203 204 205 206
    view_output_tensor->ShareInplaceVersionCounterWith(input_tensor);

    VLOG(3) << "Perform View between Output Var(" << view_output_var->Name()
            << ") and Input Var(" << input_var->Name()
            << "), share allocation and inplace version.";
  }
}
207

208
static inline PyObject* MakeReturnPyObject(
209 210 211 212 213 214 215 216
    const std::shared_ptr<paddle::imperative::VarBase>& out) {
  return ::pybind11::detail::type_caster_base<imperative::VarBase>::cast_holder(
             ::pybind11::detail::holder_helper<
                 std::shared_ptr<imperative::VarBase>>::get(out),
             &out)
      .ptr();
}

217
static inline PyObject* MakeReturnPyObject(
218 219 220 221 222
    const std::vector<std::shared_ptr<imperative::VarBase>>& out) {
  PyObject* result = PyList_New((Py_ssize_t)out.size());

  for (size_t i = 0; i < out.size(); i++) {
    PyList_SET_ITEM(
223 224
        result,
        (Py_ssize_t)i,
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250
        ::pybind11::detail::type_caster_base<imperative::VarBase>::cast_holder(
            ::pybind11::detail::holder_helper<
                std::shared_ptr<imperative::VarBase>>::get(out[i]),
            &out[i])
            .ptr());  // NOLINT
  }

  return result;
}

template <typename Tuple, size_t N>
struct TupleVarBasesResult {
  static void Run(const Tuple& out, PyObject* result) {
    TupleVarBasesResult<Tuple, N - 1>::Run(out, result);
    PyTuple_SET_ITEM(result, N - 1, MakeReturnPyObject(std::get<N - 1>(out)));
  }
};

template <typename Tuple>
struct TupleVarBasesResult<Tuple, 1> {
  static void Run(const Tuple& out, PyObject* result) {
    PyTuple_SET_ITEM(result, 0, MakeReturnPyObject(std::get<0>(out)));
  }
};

template <typename... Args>
251
PyObject* MakeReturnPyObject(const std::tuple<Args...>& out) {
252 253 254 255 256 257 258 259
  auto len = sizeof...(Args);
  PyObject* result = PyTuple_New(len);

  TupleVarBasesResult<decltype(out), sizeof...(Args)>::Run(out, result);

  return result;
}

260 261 262 263 264 265 266 267
void BindOpFunctions1(pybind11::module* module);
void BindOpFunctions2(pybind11::module* module);
void BindOpFunctions3(pybind11::module* module);
void BindOpFunctions4(pybind11::module* module);
void BindOpFunctions5(pybind11::module* module);
void BindOpFunctions6(pybind11::module* module);
void BindOpFunctions7(pybind11::module* module);
void BindOpFunctions8(pybind11::module* module);
268

269 270
}  // namespace pybind
}  // namespace paddle