// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <unordered_set>
#ifndef _WIN32
#include <unistd.h>
#endif

#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif
#include "paddle/fluid/pybind/op_function_generator.h"

// phi
#include "paddle/phi/kernels/declarations.h"

static std::string LegalizeVarName(const std::string& var_name) {
  std::string ret = var_name;
  std::replace(ret.begin(), ret.end(), '@', '_');  // replace all '@' with '_'
  return ret;
}
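// e.g. LegalizeVarName("X@GRAD") returns "X_GRAD", which is usable as part of
// a C++ identifier in the generated code.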
// clang-format off
const char* OUT_INITIALIZER_TEMPLATE =
    R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableOutput(%s)})";
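// As an illustration, for an output named "Out", OUT_INITIALIZER_TEMPLATE
// above expands (in the generated source) to:
//   {"Out", {std::shared_ptr<imperative::VarBase>(
//       new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}}
// i.e. the unique-name counter is evaluated when the generated op runs.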

const char* INPUT_INITIALIZER_TEMPLATE = R"({"%s", {%s}})";
const char* INPUT_LIST_INITIALIZER_TEMPLATE = R"({"%s", %s})";

const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
    if (%s != nullptr) {
      ins["%s"] = {%s};
    }
)";

const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
    if (%s.size() != 0) {
      ins["%s"] = %s;
    }
)";

const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
    outs["%s"] = {%s};
)";

const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
    outs["%s"] = %s;
)";
// If the input is a list, no need for braces ({}).
const char* ARG_OUT_NUM = R"(%sNum)";
const char* ARG_OUT_NUM_TYPE = R"(size_t )";

const char* IN_VAR_TYPE = R"(py::handle)";
const char* IN_VAR_LIST_TYPE = R"(py::handle)";

const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";

const char* CAST_VAR_TEMPLATE = R"(
    auto& %s = GetTensorFromArgs("%s", "%s", args, %d, %s);)";

const char* CAST_VAR_LIST_TEMPLATE = R"(
    auto %s = GetTensorListFromArgs("%s", "%s", args, %d, %s);)";

const char* CAST_VAR_PTR_TEMPLATE = R"(
    auto %s = GetTensorPtrFromArgs("%s", "%s", args, %d, %s);)";

const char* CAST_VAR_PTR_LIST_TEMPLATE = R"(
    auto %s = GetTensorPtrListFromArgs("%s", "%s", args, %d, %s);)";

const char* CAST_SIZE_T_TEMPLATE = R"(
    auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
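// As a concrete illustration: for input "X" of an op "relu" at argument
// index 0, CAST_VAR_TEMPLATE above expands to roughly
//   auto& X = GetTensorFromArgs("relu", "X", args, 0, false);
// and the list/ptr/size_t variants differ only in the helper they call.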

const char* ARG_TEMPLATE = R"(const %s& %s)";

const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])";

const char* FUNCTION_ARGS = R"(%s, const py::args& args)";
const char* FUNCTION_ARGS_NO_INPUT = R"(const py::args& args)";

const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT = R"(
    if (ins.count("%s") && outs.count("%s")) {
      HandleViewBetweenInputAndOutput(ins["%s"][0], outs["%s"][0]);
    })";
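// For a view op with the pair ("X", "Out"), this expands to:
//   if (ins.count("X") && outs.count("Out")) {
//     HandleViewBetweenInputAndOutput(ins["X"][0], outs["Out"][0]);
//   }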

const char* OP_FUNCTION_TEMPLATE =
R"(
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{
  PyThreadState *tstate = nullptr;
  try {
    %s
    framework::AttributeMap attrs;
    ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
    tstate = PyEval_SaveThread();
    %s
    PyEval_RestoreThread(tstate);
    tstate = nullptr;
    %s
  } catch(...) {
    if (tstate) {
      PyEval_RestoreThread(tstate);
    }
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }
})";
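// A rough sketch of what the template above yields for a hypothetical
// single-input op "relu" (illustrative, not generated verbatim):
//
//   static PyObject * eager_legacy_api_relu(PyObject *self, PyObject *args,
//                                           PyObject *kwargs)
//   {
//     PyThreadState *tstate = nullptr;
//     try {
//       auto& X = GetTensorFromArgs("relu", "X", args, 0, false);
//       framework::AttributeMap attrs;
//       ConstructAttrMapFromPyArgs("relu", args, 1, PyTuple_GET_SIZE(args), attrs);
//       tstate = PyEval_SaveThread();
//       auto out = relu_dygraph_function(X, attrs);
//       PyEval_RestoreThread(tstate);
//       tstate = nullptr;
//       return ToPyObject(out);
//     } catch(...) {
//       if (tstate) PyEval_RestoreThread(tstate);
//       ThrowExceptionToPython(std::current_exception());
//       return nullptr;
//     }
//   }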

const char* PYBIND_ITEM_TEMPLATE = R"(  {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)";
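// e.g. for op "relu" this registers:
//   {"relu", (PyCFunction)(void(*)(void))eager_legacy_api_relu,
//    METH_VARARGS | METH_KEYWORDS,
//    "C++ interface function for relu in dygraph."},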

// These operators will skip automatic code generation and need to be
// handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE.
std::unordered_set<std::string> CUSTOM_HANDWRITE_OPS_SET = {"run_program"};

// clang-format on
static inline bool FindInsMap(const std::string& op_type,
                              const std::string& in_name) {
  return op_ins_map[op_type].count(in_name);
}

static inline bool FindOutsMap(const std::string& op_type,
                               const std::string& out_name) {
  return op_outs_map[op_type].count(out_name);
}

static inline bool FindPassingOutsMap(const std::string& op_type,
                                      const std::string& out_name) {
  return op_passing_outs_map[op_type].count(out_name);
}

static inline bool FindViewOpMap(const std::string& op_type) {
  return view_op_map.count(op_type);
}

static inline std::string TempName(const std::string& name) {
  return name + '_';
}
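// e.g. TempName("Out") returns "Out_"; the generated function's parameters use
// these suffixed names so they do not collide with the locals created by the
// Get*FromArgs casts.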

std::string GenerateOpFunctionsBody(
    const paddle::framework::proto::OpProto* op_proto,
    std::string func_name,
    std::map<std::string, std::string> inplace_map = {}) {
  auto& op_type = op_proto->type();
  std::string input_args = "";
  std::string call_api_str = "";
  std::string ins_initializer_with_null = "";
  std::string py_arg = "";
  int arg_idx = 0;
  int input_args_num = 0;
  std::string ins_cast_str = "";
  std::string view_strategy_str = "";
  if (!inplace_map.empty()) {
    // change call_api_str for inplace op
    call_api_str = "auto out = " + op_type + "__dygraph_function(";
  } else {
    call_api_str = "auto out = " + op_type + "_dygraph_function(";
  }
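  // e.g. for op "relu" this calls relu_dygraph_function(...); the inplace
  // variant appends an extra underscore: relu__dygraph_function(...).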
  for (auto& input : op_proto->inputs()) {
    auto& in_name = input.name();
    // skip those dispensable inputs, like ResidualData in conv2d
    if (input.dispensable() && !FindInsMap(op_type, in_name)) {
      continue;
    }
    const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
    auto input_arg = paddle::string::Sprintf(
        ARG_TEMPLATE, in_type, TempName(LegalizeVarName(in_name)));
    input_args += input_arg;
    input_args += ",";
    input_args_num++;
    const auto in_cast_type =
        input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
    auto dispensable = input.dispensable() ? "true" : "false";
    ins_cast_str += paddle::string::Sprintf(in_cast_type,
                                            LegalizeVarName(in_name),
                                            op_type,
                                            in_name,
                                            arg_idx++,
                                            dispensable);

    call_api_str += LegalizeVarName(in_name) + ", ";
  }

  if (!input_args.empty() && input_args.back() == ',') {
    input_args.pop_back();
  }

  // Generate outs initializer
  std::string outs_initializer = "{";
  std::string outs_initializer_with_null = "";
  std::string return_str = "";

  int outs_num = 0;
  for (auto& output : op_proto->outputs()) {
    auto& out_name = output.name();

    // skip those dispensable outputs
    if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
      continue;
    }
    const auto out_type =
        output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE;

    if (FindPassingOutsMap(op_type, out_name)) {
      if (input_args != "") {
        input_args += ",";
      }
      input_args += out_type;
      input_args += LegalizeVarName(out_name);
      input_args_num++;

      if (output.dispensable()) {
        const auto out_template =
            output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
                                : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL;
        outs_initializer_with_null +=
            paddle::string::Sprintf(out_template, out_name, out_name);
      } else {
        const auto out_template = output.duplicable()
                                      ? INPUT_LIST_INITIALIZER_TEMPLATE
                                      : INPUT_INITIALIZER_TEMPLATE;
        outs_initializer += paddle::string::Sprintf(
            out_template, out_name, LegalizeVarName(out_name));
        outs_initializer += ",";
      }

      const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE
                                                    : CAST_VAR_PTR_TEMPLATE;
      auto dispensable = output.dispensable() ? "true" : "false";
      ins_cast_str += paddle::string::Sprintf(in_cast_type,
                                              LegalizeVarName(out_name),
                                              op_type,
                                              out_name,
                                              arg_idx++,
                                              dispensable);

      call_api_str += LegalizeVarName(out_name) + ", ";
    } else {
      // A few operators have a duplicable output, like `Out` in the split
      // op. For those, the number of output variables must be passed in as
      // the extra argument OutNum.
      if (output.duplicable()) {
        if (input_args != "") {
          input_args += ",";
        }
        auto out_num_str =
            paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
        input_args += ARG_OUT_NUM_TYPE;
        input_args += out_num_str;
        input_args_num++;
        outs_initializer += paddle::string::Sprintf(
            OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);

        auto dispensable = output.dispensable() ? "true" : "false";
        ins_cast_str += paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE,
                                                out_num_str,
                                                op_type,
                                                out_num_str,
                                                arg_idx++,
                                                dispensable);
        call_api_str += out_num_str + ", ";
      } else {
        outs_initializer +=
            paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
      }
      outs_initializer += ",";
    }

    // return_str += paddle::string::Sprintf(return_template, out_name);
    // return_str += ",";
    outs_num += 1;
  }
  call_api_str += "attrs);";
  if (outs_initializer.back() == ',') {
    outs_initializer.pop_back();
    // return_str.pop_back();
  }
  outs_initializer += "}";
  if (FindViewOpMap(op_type)) {
    std::string view_input_name = view_op_map[op_type].first;
    std::string view_output_name = view_op_map[op_type].second;
    view_strategy_str +=
        paddle::string::Sprintf(HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT,
                                view_input_name,
                                view_output_name,
                                view_input_name,
                                view_output_name);
  }
  if (!inplace_map.empty()) {
    // For inplace op, Use the input PyObject directly.
    return_str = "std::map<ssize_t, ssize_t> inplace_var_idx_map;\n";
    for (auto& inplace_pair : inplace_map) {
      // Find index of inplace tensor, and directly use input PyObject.
      std::string inplace_arg_name = inplace_pair.second;
      std::string inplace_return_name = inplace_pair.first;
      const char* RETURN_INPLACE_TENSOR_TEMPLATE =
          "    ssize_t arg_id = "
          "GetIdxFromCoreOpsInfoMap(core_ops_legacy_args_info, "
          "\"%s\", \"%s\");\n"
          "    ssize_t return_id = "
          "GetIdxFromCoreOpsInfoMap(core_ops_legacy_returns_info, \"%s\", "
          "\"%s\");\n"
          "    inplace_var_idx_map[return_id] = arg_id;";
      return_str += paddle::string::Sprintf(RETURN_INPLACE_TENSOR_TEMPLATE,
                                            op_type,
                                            inplace_arg_name,
                                            op_type,
                                            inplace_return_name);
    }
    return_str += "    return ToPyObject(out, args, inplace_var_idx_map);";
  } else {
    return_str = "return ToPyObject(out);";
  }
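  // For an inplace op, the generated code thus maps each returned tensor back
  // to the positional index of the input it reuses, so ToPyObject can hand
  // back the original input PyObject instead of allocating a new one.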

  std::string function_args = "";
  if (input_args == "") {
    function_args = FUNCTION_ARGS_NO_INPUT;
  } else {
    function_args = paddle::string::Sprintf(FUNCTION_ARGS, input_args);
  }

  // generate op function body
  auto op_function_str = paddle::string::Sprintf(OP_FUNCTION_TEMPLATE,
                                                 func_name,
                                                 ins_cast_str,
                                                 op_type,
                                                 input_args_num,
                                                 call_api_str,
                                                 return_str);

  return op_function_str;
}

static std::string GenerateCoreOpsInfoMap() {
  std::string result =
      "static PyObject * eager_get_core_ops_args_info(PyObject *self) {\n"
      "  PyThreadState *tstate = nullptr;\n"
      "  try {\n"
      "    return ToPyObject(core_ops_legacy_args_info);\n"
      "  } catch(...) {\n"
      "    if (tstate) {\n"
      "      PyEval_RestoreThread(tstate);\n"
      "    }\n"
      "    ThrowExceptionToPython(std::current_exception());\n"
      "    return nullptr;\n"
      "  }\n"
      "}\n"
      "\n"
      "static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {\n"
      "  PyThreadState *tstate = nullptr;\n"
      "  try {\n"
      "    return ToPyObject(core_ops_legacy_args_type_info);\n"
      "  } catch(...) {\n"
      "    if (tstate) {\n"
      "      PyEval_RestoreThread(tstate);\n"
      "    }\n"
      "    ThrowExceptionToPython(std::current_exception());\n"
      "    return nullptr;\n"
      "  }\n"
      "}\n"
      "\n"
      "static PyObject * eager_get_core_ops_returns_info(PyObject *self) {\n"
      "  PyThreadState *tstate = nullptr;\n"
      "  try {\n"
      "    return ToPyObject(core_ops_legacy_returns_info);\n"
      "  } catch(...) {\n"
      "    if (tstate) {\n"
      "      PyEval_RestoreThread(tstate);\n"
      "    }\n"
      "    ThrowExceptionToPython(std::current_exception());\n"
      "    return nullptr;\n"
      "  }\n"
      "}\n";

  return result;
}

static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions() {
  auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();

  std::vector<std::string> op_function_list, bind_function_list;
  auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
  for (auto& pair : op_info_map) {
    auto& op_info = pair.second;
    auto op_proto = op_info.proto_;
    if (op_proto == nullptr) {
      continue;
    }
    auto& op_type = op_proto->type();
    // Skip operators that will be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE.
    if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) {
      continue;
    }
    // Skip the sparse op
    if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" &&
        op_type != "sparse_attention") {
      continue;
    }
    // Skip operators that do not inherit from OperatorWithKernel, like while,
    // since only OperatorWithKernel can run in dygraph mode.
    // If the phi lib contains the op kernel, we still generate the op method.
    if (!all_kernels.count(op_type) &&
        !phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type)) {
      continue;
    }
    std::string func_name = "eager_legacy_api_" + op_type;
    std::string op_function_str =
        GenerateOpFunctionsBody(op_proto, func_name, {});

    // generate pybind item
    auto bind_function_str = paddle::string::Sprintf(
        PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);

    op_function_list.emplace_back(std::move(op_function_str));
    bind_function_list.emplace_back(std::move(bind_function_str));

    // NOTE(pangyoki): Inplace Strategy.
    // In this case, output will reuse input varbase.
    // Dygraph mode needs to be aligned with the in-place strategy in static
    // mode, and the mapping relationships between output and input that have
    // been defined in static mode should be used in dygraph mode.
    // Find which ops need to use Inplace strategy in static mode, and get the
    // mapping relationship between Inplace output and input.
    auto& infer_inplace =
        paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
    std::map<std::string, std::string> inplace_map;
    // The `sum` op has duplicable inputs, so we do not add an inplace
    // strategy for `sum` for now.
    if (infer_inplace && !special_inplace_op_set.count(op_type)) {
      // Inplace OP: op_type_.
      // The inplace OP needs its own generated implementation.
      auto in_to_outs = infer_inplace(true);
      for (auto& inplace_pair : in_to_outs) {
        inplace_map[inplace_pair.second] = inplace_pair.first;
      }

      std::string inplace_op_type = op_type + "_";
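      // e.g. "reshape2" yields "reshape2_", following the trailing-underscore
      // naming convention for inplace APIs.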
      std::string inplace_func_name = "eager_legacy_api_" + inplace_op_type;
      std::string inplace_op_function_str =
          GenerateOpFunctionsBody(op_proto, inplace_func_name, inplace_map);

      // generate pybind item
      auto inplace_bind_function_str =
          paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE,
                                  inplace_op_type,
                                  inplace_func_name,
                                  inplace_op_type);

      op_function_list.emplace_back(std::move(inplace_op_function_str));
      bind_function_list.emplace_back(std::move(inplace_bind_function_str));
    }
  }

  return std::make_tuple(op_function_list, bind_function_list);
}

int main(int argc, char* argv[]) {
  if (argc != 2) {
    std::cerr << "argc must be 2" << std::endl;
    return -1;
  }

#ifdef PADDLE_WITH_ASCEND_CL
  auto ascend_ptr = paddle::framework::AscendInstance::GetInstance();
  ascend_ptr->InitGEForUT();
#endif

  std::vector<std::string> headers{
      "<Python.h>",
      "\"paddle/fluid/platform/enforce.h\"",
      "\"paddle/fluid/eager/api/generated/fluid_generated/"
      "dygraph_forward_api.h\"",
      "\"paddle/fluid/pybind/eager_utils.h\"",
      "\"paddle/fluid/platform/profiler/event_tracing.h\"",
      "\"paddle/fluid/pybind/exception.h\"",
      "\"paddle/fluid/pybind/op_function_common.h\"",
      "\"paddle/fluid/pybind/eager_legacy_custom_python_api.h\"",
      "\"paddle/fluid/pybind/eager.h\""};

  std::ofstream out(argv[1], std::ios::out);

  for (auto& header : headers) {
    out << "#include  " + header + "\n";
  }

  out << "\n\n";

  auto op_funcs = GenerateOpFunctions();
  auto core_ops_infos = GenerateCoreOpsInfoMap();
  std::string core_ops_infos_registry =
      "  {\"get_core_ops_args_info\", "
      "(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, "
      "\"C++ interface function for eager_get_core_ops_args_info.\"},\n"
      "  {\"get_core_ops_args_type_info\", "
      "(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, "
      "METH_NOARGS, "
      "\"C++ interface function for eager_get_core_ops_args_type_info.\"},\n"
      "  {\"get_core_ops_returns_info\", "
      "(PyCFunction)(void(*)(void))eager_get_core_ops_returns_info, "
      "METH_NOARGS, \"C++ interface function for "
      "eager_get_core_ops_returns_info.\"},\n";

  out << "namespace paddle {\n"
      << "namespace pybind {\n\n";
  out << core_ops_infos;
  out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
  out << "\n\n";

  out << "static PyMethodDef ExtestMethods[] = {\n"
      << paddle::string::join_strings(std::get<1>(op_funcs), '\n') << "\n"
      << core_ops_infos_registry << "\n  {nullptr,nullptr,0,nullptr}"
      << "};\n\n";

  out << "void BindEagerOpFunctions(pybind11::module *module) {\n"
      << "  InitOpsAttrTypeMap();\n"
      << "  auto m = module->def_submodule(\"ops\");\n"
      << "  auto legacy = m.def_submodule(\"legacy\");\n"
      << "  if (PyModule_AddFunctions(legacy.ptr(), ExtestMethods) < 0) {\n"
      << "    PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
         "core.eager.ops failed!\"));\n"
      << "  }\n\n"
      << "  if (PyModule_AddFunctions(legacy.ptr(), CustomEagerMethods) < "
         "0) {\n"
      << "    PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
         "core.eager.ops failed!\"));\n"
      << "  }\n\n"
      << "  BindFinalStateEagerOpFunctions(&m);\n"
      << "}\n\n"
      << "} // namespace pybind\n"
      << "} // namespace paddle\n";

  out.close();

#ifdef PADDLE_WITH_ASCEND_CL
  ge::GEFinalize();
#endif

  return 0;
}