// op_function_generator.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/op_function_generator.h"

#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#ifndef _WIN32
#include <unistd.h>
#endif

#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"

#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif

// phi
#include "paddle/phi/kernels/declarations.h"

J
Jiabin Yang 已提交
38 39 40 41 42 43
static std::string LegalizeVarName(const std::string& var_name) {
  std::string ret = var_name;
  std::replace(ret.begin(), ret.end(), '@', '_');  // replace all '-' to '_'
  return ret;
}

44 45 46 47 48 49 50 51
// NOTE(pangyoki): Inplace OP with duplicable input.
// The set includes inplace ops that have duplicable input.
// The first Varbase in input needs to be specified for the inplace strategy
// and share Varbase with the output.
std::set<std::string> inplace_op_duplicable_ins_set = {
    "sum",
};

52
// clang-format off
53
const char* OUT_INITIALIZER_TEMPLATE =
54
    R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
55 56 57 58
const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableOutput(%s)})";

const char* INPUT_INITIALIZER_TEMPLATE = R"({"%s", {%s}})";
const char* INPUT_LIST_INITIALIZER_TEMPLATE = R"({"%s", %s})";
L
Leo Chen 已提交
59

60 61 62 63
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
    if (%s != nullptr) {
      ins["%s"] = {%s};
    }
64
)";
L
Leo Chen 已提交
65

66
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
L
Leo Chen 已提交
67
    if (%s.size() != 0) {
68 69
      ins["%s"] = %s;
    }
L
Leo Chen 已提交
70 71
)";

72 73
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
    outs["%s"] = {%s};
74 75
)";

76 77
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
    outs["%s"] = %s;
L
Leo Chen 已提交
78
)";
79 80 81 82
// if inputs is list, no need {}
const char* ARG_OUT_NUM = R"(%sNum)";
const char* ARG_OUT_NUM_TYPE = R"(size_t )";

83 84 85 86 87 88 89
const char* IN_VAR_TYPE = R"(py::handle)";
const char* IN_VAR_LIST_TYPE = R"(py::handle)";

const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";

const char* CAST_VAR_TEMPLATE = R"(
Z
zyfncg 已提交
90
    auto %s = GetVarBaseFromArgs(op_type, "%s", args, %d, %s);)";
91 92

const char* CAST_VAR_LIST_TEMPLATE = R"(
Z
zyfncg 已提交
93
    auto %s = GetVarBaseListFromArgs(op_type, "%s", args, %d, %s);)";
94

95
const char* CAST_SIZE_T_TEMPLATE = R"(
Z
zyfncg 已提交
96
    auto %s = GetUnsignedLongFromArgs(op_type, "%s", args, %d, %s);)";
97

98 99 100 101 102 103 104 105 106
const char* ARG_TEMPLATE = R"(const %s& %s)";

const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])";

const char* FUNCTION_ARGS = R"(%s, const py::args& args)";
const char* FUNCTION_ARGS_NO_INPUT = R"(const py::args& args)";
107

108
const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT = R"(
109 110 111 112
    if (ins.count("%s") && outs.count("%s")) {
      HandleViewBetweenInputAndOutput(ins["%s"][0], outs["%s"][0]);
    })";

113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
const char* INPLACE_DUPLICABLE_INPUT = R"([0])";

const char* INPLACE_LEAF_ERROR_MESSAGE = R"(Leaf Var (%s) that doesn't stop gradient can't use inplace strategy.)";

const char* INPLACE_STRATEGY_TEMPLATE =
R"(
    PADDLE_ENFORCE_EQ(
      %s->IsLeaf() && !%s->OverridedStopGradient(), false,
      platform::errors::InvalidArgument("%s", %s->Name()));
    %s->BumpInplaceVersion();
    VLOG(3) << "Var(" << %s->Name() << ") uses Inplace Strategy.";
)";

const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";

128
const char* OP_FUNCTION_TEMPLATE =
129
R"(
130
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
131
{
132 133
  PyThreadState *tstate = nullptr;
  try
134
  {
Z
zyfncg 已提交
135
    std::string op_type = "%s";
136
    platform::RecordEvent op_type_record_event("%s pybind_imperative_func");
137 138
    %s
    framework::AttributeMap attrs;
Z
zyfncg 已提交
139
    ConstructAttrMapFromPyArgs(op_type, args, %d, PyTuple_GET_SIZE(args) , attrs);
140
    tstate = PyEval_SaveThread();
141
    %s
142 143 144
    imperative::NameVarBaseMap outs = %s;
    imperative::NameVarBaseMap ins = %s;
    %s
Z
zyfncg 已提交
145
    imperative::GetCurrentTracer()->TraceOp(op_type, ins, outs, attrs, {%s});
146 147
    PyEval_RestoreThread(tstate);
    tstate = nullptr;
148
    %s
149
  }
150 151 152 153 154 155 156
  catch(...) {
    if (tstate) {
      PyEval_RestoreThread(tstate);
    }
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }
157
})";
158

159
const char* PYBIND_ITEM_TEMPLATE = R"(  {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)";
160

161
// clang-format on
L
Leo Chen 已提交
162 163
static inline bool FindInsMap(const std::string& op_type,
                              const std::string& in_name) {
164 165 166
  return op_ins_map[op_type].count(in_name);
}

L
Leo Chen 已提交
167 168 169 170 171 172 173 174
static inline bool FindOutsMap(const std::string& op_type,
                               const std::string& out_name) {
  return op_outs_map[op_type].count(out_name);
}

static inline bool FindPassingOutsMap(const std::string& op_type,
                                      const std::string& out_name) {
  return op_passing_outs_map[op_type].count(out_name);
175
}
176

177 178 179 180
static inline bool FindDuplicableInputInplaceOpSet(const std::string& op_type) {
  return inplace_op_duplicable_ins_set.count(op_type);
}

181 182 183 184
static inline bool FindViewOpMap(const std::string& op_type) {
  return view_op_map.count(op_type);
}

185 186 187 188
static inline std::string TempName(const std::string& name) {
  return name + '_';
}

189
std::string GenerateOpFunctionsBody(
190 191
    const paddle::framework::proto::OpProto* op_proto,
    std::string func_name,
192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
    bool use_inplace_strategy = false,
    std::map<std::string, std::string> inplace_map = {}) {
  auto& op_type = op_proto->type();
  std::string input_args = "";
  std::string ins_initializer = "{";
  std::string ins_initializer_with_null = "";
  std::string py_arg = "";
  int arg_idx = 0;
  int input_args_num = 0;
  std::string ins_cast_str = "";
  std::string view_strategy_str = "";
  std::string inplace_strategy_str = "";
  for (auto& input : op_proto->inputs()) {
    auto& in_name = input.name();
    // skip those dispensable inputs, like ResidualData in conv2d
    if (input.dispensable() && !FindInsMap(op_type, in_name)) {
      continue;
    }
    const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
J
Jiabin Yang 已提交
211 212
    auto input_arg = paddle::string::Sprintf(
        ARG_TEMPLATE, in_type, LegalizeVarName(TempName(in_name)));
213 214 215 216 217 218
    input_args += input_arg;
    input_args += ",";
    input_args_num++;
    const auto in_cast_type =
        input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
    auto dispensable = input.dispensable() ? "true" : "false";
219 220 221 222 223
    ins_cast_str += paddle::string::Sprintf(in_cast_type,
                                            LegalizeVarName(in_name),
                                            in_name,
                                            arg_idx++,
                                            dispensable);
224 225 226 227 228 229

    if (input.dispensable()) {
      const auto in_template = input.duplicable()
                                   ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
                                   : INPUT_INITIALIZER_TEMPLATE_WITH_NULL;
      ins_initializer_with_null +=
230 231 232 233
          paddle::string::Sprintf(in_template,
                                  LegalizeVarName(in_name),
                                  in_name,
                                  LegalizeVarName(in_name));
234 235 236 237
    } else {
      const auto in_template = input.duplicable()
                                   ? INPUT_LIST_INITIALIZER_TEMPLATE
                                   : INPUT_INITIALIZER_TEMPLATE;
238 239
      ins_initializer += paddle::string::Sprintf(
          in_template, in_name, LegalizeVarName(in_name));
240 241 242 243 244 245 246 247
      ins_initializer += ",";
    }
  }
  if (ins_initializer.back() == ',') {
    ins_initializer.pop_back();
  }
  ins_initializer += "}";

248
  if (!input_args.empty() && input_args.back() == ',') {
249 250 251 252 253 254 255 256 257 258 259 260
    input_args.pop_back();
  }

  // Generate outs initializer
  std::string outs_initializer = "{";
  std::string outs_initializer_with_null = "";
  std::string inplace_mapping_str = "";
  std::string return_str = "";

  int outs_num = 0;
  for (auto& output : op_proto->outputs()) {
    auto& out_name = output.name();
261

262 263 264 265 266 267 268 269 270 271 272 273 274 275
    // skip those dispensable oututs
    if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
      continue;
    }
    const auto out_type =
        output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE;
    const auto return_template =
        output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE;

    if (FindPassingOutsMap(op_type, out_name)) {
      if (input_args != "") {
        input_args += ",";
      }
      input_args += out_type;
J
Jiabin Yang 已提交
276
      input_args += LegalizeVarName(out_name);
277 278 279 280 281 282 283 284 285 286 287 288
      input_args_num++;

      if (output.dispensable()) {
        const auto out_template =
            output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
                                : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL;
        outs_initializer_with_null +=
            paddle::string::Sprintf(out_template, out_name, out_name);
      } else {
        const auto out_template = output.duplicable()
                                      ? INPUT_LIST_INITIALIZER_TEMPLATE
                                      : INPUT_INITIALIZER_TEMPLATE;
289 290
        outs_initializer += paddle::string::Sprintf(
            out_template, out_name, LegalizeVarName(out_name));
291 292
        outs_initializer += ",";
      }
293 294 295 296

      const auto in_cast_type =
          output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
      auto dispensable = output.dispensable() ? "true" : "false";
297 298 299 300 301
      ins_cast_str += paddle::string::Sprintf(in_cast_type,
                                              LegalizeVarName(out_name),
                                              out_name,
                                              arg_idx++,
                                              dispensable);
302 303
    } else if (use_inplace_strategy && inplace_map.count(out_name)) {
      PADDLE_ENFORCE_NE(
304 305
          inplace_map[out_name],
          "",
306
          paddle::platform::errors::InvalidArgument(
307 308
              "Inplace op %s has no input corresponding to output %s.",
              op_type,
309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
              out_name));

      // TODO(pangyoki): Inplace op don't have duplicable output in temporary,
      // so don't support duplicable output now.
      const auto out_template = INPUT_INITIALIZER_TEMPLATE;

      auto inplace_input_name = inplace_map[out_name];
      inplace_mapping_str += paddle::string::Sprintf(
          INPLACE_MAPPING_TEMPLATE, inplace_input_name, out_name);
      inplace_mapping_str += ",";

      // If inplace op has duplicable input, the first Varbase in input will
      // share Varbase with output.
      if (FindDuplicableInputInplaceOpSet(op_type)) {
        inplace_input_name += INPLACE_DUPLICABLE_INPUT;
      }

      // Leaf Var that doesn't stop gradient can't use inplace strategy.
      // Increase inplace_version.
328 329 330 331 332 333 334 335
      inplace_strategy_str +=
          paddle::string::Sprintf(INPLACE_STRATEGY_TEMPLATE,
                                  LegalizeVarName(inplace_input_name),
                                  LegalizeVarName(inplace_input_name),
                                  INPLACE_LEAF_ERROR_MESSAGE,
                                  LegalizeVarName(inplace_input_name),
                                  LegalizeVarName(inplace_input_name),
                                  LegalizeVarName(inplace_input_name));
J
Jiabin Yang 已提交
336 337
      outs_initializer += paddle::string::Sprintf(
          out_template, out_name, LegalizeVarName(inplace_input_name));
338 339 340 341 342 343 344 345 346
      outs_initializer += ",";
    } else {
      // There are few Operators that have duplicable output, like `Out` in
      // split op. We need to specify the number of variables for the
      // duplicable output, as the argument OutNum;
      if (output.duplicable()) {
        if (input_args != "") {
          input_args += ",";
        }
J
Jiabin Yang 已提交
347 348
        auto out_num_str =
            paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
349 350 351 352 353
        input_args += ARG_OUT_NUM_TYPE;
        input_args += out_num_str;
        input_args_num++;
        outs_initializer += paddle::string::Sprintf(
            OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
354 355

        auto dispensable = output.dispensable() ? "true" : "false";
356 357 358 359 360
        ins_cast_str += paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE,
                                                out_num_str,
                                                out_num_str,
                                                arg_idx++,
                                                dispensable);
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
      } else {
        outs_initializer +=
            paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
      }
      outs_initializer += ",";
    }

    return_str += paddle::string::Sprintf(return_template, out_name);
    return_str += ",";
    outs_num += 1;
  }
  if (outs_initializer.back() == ',') {
    outs_initializer.pop_back();
    return_str.pop_back();
  }
  outs_initializer += "}";
377
  if (!inplace_mapping_str.empty() && inplace_mapping_str.back() == ',') {
378 379 380 381 382
    inplace_mapping_str.pop_back();
  }
  if (!use_inplace_strategy && FindViewOpMap(op_type)) {
    std::string viwe_input_name = view_op_map[op_type].first;
    std::string viwe_output_name = view_op_map[op_type].second;
383 384 385 386 387 388
    view_strategy_str +=
        paddle::string::Sprintf(HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT,
                                viwe_input_name,
                                viwe_output_name,
                                viwe_input_name,
                                viwe_output_name);
389 390
  }
  if (outs_num == 0) {
391
    return_str = "RETURN_PY_NONE";
392
  } else if (outs_num == 1) {
393
    return_str = "return MakeReturnPyObject(" + return_str + ");";
394
  } else {
395
    return_str = "return MakeReturnPyObject(" +
396
                 paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str) +
397
                 ");";
398 399 400 401 402 403 404 405 406 407
  }
  std::string function_args = "";
  if (input_args == "") {
    function_args = FUNCTION_ARGS_NO_INPUT;
  } else {
    function_args = paddle::string::Sprintf(FUNCTION_ARGS, input_args);
  }

  // generate op funtcion body
  auto op_function_str = paddle::string::Sprintf(
408 409 410 411 412 413 414 415 416
      OP_FUNCTION_TEMPLATE,
      func_name,
      op_type,
      op_type,
      ins_cast_str,
      input_args_num,
      inplace_strategy_str,
      outs_initializer,
      ins_initializer,
417 418
      ins_initializer_with_null + outs_initializer_with_null +
          view_strategy_str,
419 420
      inplace_mapping_str,
      return_str);
421 422 423 424

  return op_function_str;
}

425
static std::tuple<std::vector<std::string>, std::vector<std::string>>
426
GenerateOpFunctions() {
427 428
  auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();

429
  std::vector<std::string> op_function_list, bind_function_list;
430 431
  auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();

432 433 434 435 436 437 438
  for (auto& pair : op_info_map) {
    auto& op_info = pair.second;
    auto op_proto = op_info.proto_;
    if (op_proto == nullptr) {
      continue;
    }
    auto& op_type = op_proto->type();
439
    // Skip operator which is not inherit form OperatorWithKernel, like while,
440
    // since only OperatorWithKernel can run in dygraph mode.
441
    // if the phi lib contains op kernel, we still generate ops method
442
    if (!all_kernels.count(op_type) &&
443
        !phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type)) {
444 445
      continue;
    }
446

447 448 449 450 451 452 453 454 455 456 457 458 459 460
    // NOTE(pangyoki): Inplace Strategy.
    // In this case, output will reuse input varbase.
    // Dygraph mode needs to be aligned with the in-place strategy in static
    // mode, and the mapping relationships between output and input that have
    // been defined in static mode should be used in dygraph mode.
    // Find which ops need to use Inplace strategy in static mode, and get the
    // mapping relationship between Inplace output and input.
    auto& infer_inplace =
        paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
    std::map<std::string, std::string> inplace_map;
    if (infer_inplace) {
      auto in_to_outs = infer_inplace(true);
      for (auto& inplace_pair : in_to_outs) {
        inplace_map[inplace_pair.second] = inplace_pair.first;
461 462
      }
    }
463

464
    std::string func_name = "imperative_" + op_type;
465
    std::string op_function_str = GenerateOpFunctionsBody(op_proto, func_name);
466 467

    // generate pybind item
468
    auto bind_function_str = paddle::string::Sprintf(
469
        PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);
470 471 472

    op_function_list.emplace_back(std::move(op_function_str));
    bind_function_list.emplace_back(std::move(bind_function_str));
473 474 475 476 477 478 479 480 481 482 483

    if (infer_inplace) {
      // Reuse Varbase Inplace OP: op_type_.
      // The inplace OP needs a new implementation method.
      std::string inplace_op_type = op_type + "_";
      std::string inplace_func_name = "imperative_" + inplace_op_type;
      std::string inplace_op_function_str = GenerateOpFunctionsBody(
          op_proto, inplace_func_name, true, inplace_map);

      // generate pybind item
      auto inplace_bind_function_str =
484 485 486 487
          paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE,
                                  inplace_op_type,
                                  inplace_func_name,
                                  inplace_op_type);
488 489 490 491

      op_function_list.emplace_back(std::move(inplace_op_function_str));
      bind_function_list.emplace_back(std::move(inplace_bind_function_str));
    }
492
  }
493
  return std::make_tuple(op_function_list, bind_function_list);
494 495 496 497 498 499 500 501
}

int main(int argc, char* argv[]) {
  if (argc != 2) {
    std::cerr << "argc must be 2" << std::endl;
    return -1;
  }

502
#ifdef PADDLE_WITH_ASCEND_CL
503 504 505 506
  auto ascend_ptr = paddle::framework::AscendInstance::GetInstance();
  ascend_ptr->InitGEForUT();
#endif

507
  std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
508
                                   "\"paddle/fluid/platform/profiler.h\"",
509 510
                                   "\"pybind11/numpy.h\"",
                                   "\"pybind11/pybind11.h\"",
511
                                   "\"pybind11/detail/common.h\"",
512 513
                                   "\"paddle/fluid/pybind/eager_utils.h\"",
                                   "\"paddle/fluid/pybind/op_function.h\"",
514
                                   "<Python.h>"};
515 516 517 518 519 520 521

  std::ofstream out(argv[1], std::ios::out);

  for (auto& header : headers) {
    out << "#include  " + header + "\n";
  }

522 523 524
  out << "\n\n";

  auto op_funcs = GenerateOpFunctions();
525

526
  out << "namespace paddle {\n"
527 528
      << "namespace pybind {\n\n";
  out << "std::atomic<int> VarBaseUniqueNameID{0};\n";
529 530
  out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
  out << "\n\n";
531

532 533 534 535
  out << "static PyMethodDef ExtestMethods[] = {\n"
      << paddle::string::join_strings(std::get<1>(op_funcs), '\n')
      << "\n  {nullptr,nullptr,0,nullptr}"
      << "};\n\n";
536

537
  out << "void BindOpFunctions(pybind11::module *module) {\n"
538 539 540 541 542 543 544
      << "  auto m = module->def_submodule(\"ops\");\n"
      << "  if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
      << "    PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
         "core.ops failed!\"));\n"
      << "  }\n\n"
      << "  InitOpsAttrTypeMap();"
      << "}\n\n"
545 546 547 548
      << "} // namespace pybind\n"
      << "} // namespace paddle\n";

  out.close();
549

550
#ifdef PADDLE_WITH_ASCEND_CL
551 552
  ge::GEFinalize();
#endif
553

554 555
  return 0;
}