// ops_api.cc
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <pybind11/pybind11.h>

#include "paddle/fluid/pybind/static_op_function.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace pybind {

23 24 25 26
static PyObject *add_n(PyObject *self, PyObject *args, PyObject *kwargs) {
  return static_api_add_n(self, args, kwargs);
}

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
static PyObject *mean(PyObject *self, PyObject *args, PyObject *kwargs) {
  return static_api_mean(self, args, kwargs);
}

// Entry point exposed to Python as "sum" through the OpsAPI table.
// All arguments are handed straight to static_api_sum, and whatever it
// produces is passed back to the interpreter untouched.
static PyObject *sum(PyObject *self, PyObject *args, PyObject *kwargs) {
  PyObject *const result = static_api_sum(self, args, kwargs);
  return result;
}

// Entry point exposed to Python as "full" through the OpsAPI table.
// All arguments are handed straight to static_api_full, and whatever it
// produces is passed back to the interpreter untouched.
static PyObject *full(PyObject *self, PyObject *args, PyObject *kwargs) {
  PyObject *const result = static_api_full(self, args, kwargs);
  return result;
}

// Entry point exposed to Python as "divide" through the OpsAPI table.
// All arguments are handed straight to static_api_divide, and whatever it
// produces is passed back to the interpreter untouched.
static PyObject *divide(PyObject *self, PyObject *args, PyObject *kwargs) {
  PyObject *const result = static_api_divide(self, args, kwargs);
  return result;
}

43 44 45 46 47 48 49 50
static PyObject *data(PyObject *self, PyObject *args, PyObject *kwargs) {
  return static_api_data(self, args, kwargs);
}

// Entry point exposed to Python as "fetch" through the OpsAPI table.
// All arguments are handed straight to static_api_fetch, and whatever it
// produces is passed back to the interpreter untouched.
static PyObject *fetch(PyObject *self, PyObject *args, PyObject *kwargs) {
  PyObject *const result = static_api_fetch(self, args, kwargs);
  return result;
}

51 52 53 54
static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) {
  return static_api_concat(self, args, kwargs);
}

55 56 57 58
static PyObject *split(PyObject *self, PyObject *args, PyObject *kwargs) {
  return static_api_split(self, args, kwargs);
}

59
static PyMethodDef OpsAPI[] = {{"add_n",
60 61 62 63
                                (PyCFunction)(void (*)(void))add_n,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for add_n."},
                               {"mean",
64 65 66 67 68 69 70 71 72 73 74
                                (PyCFunction)(void (*)(void))mean,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for mean."},
                               {"sum",
                                (PyCFunction)(void (*)(void))sum,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for sum."},
                               {"divide",
                                (PyCFunction)(void (*)(void))divide,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for divide."},
75 76 77 78
                               {"concat",
                                (PyCFunction)(void (*)(void))concat,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for concat."},
79 80 81 82
                               {"full",
                                (PyCFunction)(void (*)(void))full,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for full."},
83 84 85 86
                               {"split",
                                (PyCFunction)(void (*)(void))split,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for split."},
87 88 89 90 91 92 93 94
                               {"data",
                                (PyCFunction)(void (*)(void))data,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for data."},
                               {"fetch",
                                (PyCFunction)(void (*)(void))fetch,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for fetch."},
95 96 97 98 99 100 101 102 103 104
                               {nullptr, nullptr, 0, nullptr}};

// Installs every entry of the OpsAPI table as a function on `module`.
// PyModule_AddFunctions signals failure with a negative return value; in
// that case a fatal Paddle error is raised via PADDLE_THROW.
void BindOpsAPI(pybind11::module *module) {
  PyObject *target = module->ptr();
  const int status = PyModule_AddFunctions(target, OpsAPI);
  if (status >= 0) {
    return;
  }
  PADDLE_THROW(phi::errors::Fatal("Add C++ api to core.ops failed!"));
}

}  // namespace pybind
}  // namespace paddle