eager_custom_python_api.h 2.3 KB
Newer Older
0
0x45f 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <iostream>
17

18
#include "paddle/phi/core/enforce.h"
0
0x45f 已提交
19

20 21 22
namespace paddle {
namespace pybind {

23 24 25
// Python-facing entry point for the eager-mode "linear" operator.
//
// Reads three tensors from the positional `args` tuple (index 0 "X",
// index 1 "weight", index 2 "Bias"), computes matmul_ad_func(X, weight) and,
// when the bias tensor reports initialized(), add_ad_func(product, Bias).
// The result is converted with ToPyObject and returned.
//
// If a C++ exception escapes the computation it is forwarded with
// ThrowExceptionToPython and nullptr is returned. `self` and `kwargs` are
// not read by this function.
static PyObject *eager_api_linear(PyObject *self,
                                  PyObject *args,
                                  PyObject *kwargs) {
  // Non-null only between PyEval_SaveThread() and the matching restore.
  PyThreadState *saved_state = nullptr;

  // Restores the saved thread state (re-acquiring the GIL) at most once.
  const auto restore_thread_if_saved = [&saved_state]() {
    if (saved_state != nullptr) {
      PyEval_RestoreThread(saved_state);
      saved_state = nullptr;
    }
  };

  try {
    // NOTE(review): the trailing flag is true only for "Bias" -- presumably
    // it marks the argument as optional; confirm against GetTensorFromArgs.
    auto input = GetTensorFromArgs("linear", "X", args, 0, false);
    auto weight_tensor = GetTensorFromArgs("linear", "weight", args, 1, false);
    auto bias_tensor = GetTensorFromArgs("linear", "Bias", args, 2, true);

    // Release the GIL for the duration of the tensor computation.
    saved_state = PyEval_SaveThread();

    const bool has_bias = bias_tensor.initialized();
    auto product = matmul_ad_func(input, weight_tensor, false, false);

    if (!has_bias) {
      // No bias supplied: the matmul result is the final output.
      restore_thread_if_saved();
      return ToPyObject(product);
    }

    auto biased = add_ad_func(product, bias_tensor);
    // The thread state must be restored before building a Python object.
    restore_thread_if_saved();
    return ToPyObject(biased);
  } catch (paddle::platform::EnforceNotMet &enforce_error) {
    restore_thread_if_saved();
    // Tag the message with the operator name before handing it to Python.
    std::ostringstream message;
    message << enforce_error.what() << "  [operator < linear > error]";
    enforce_error.set_error_str(message.str());
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  } catch (...) {
    restore_thread_if_saved();
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }
}

63 64 65
// Method table that exposes the custom eager C++ entry points to Python.
// The array ends with an all-null sentinel entry.
static PyMethodDef CustomEagerFinalStateMethods[] = {
    {"linear",
     // eager_api_linear takes (self, args, kwargs); the cast goes through
     // void (*)(void) to store it in the two-argument PyCFunction slot.
     (PyCFunction)(void (*)(void))eager_api_linear,
     METH_VARARGS | METH_KEYWORDS,
     "C++ interface function for linear in dygraph."},
    {nullptr, nullptr, 0, nullptr}};
69 70 71

}  // namespace pybind
}  // namespace paddle