op_function_generator.cc 18.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/op_function_generator.h"

#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifndef _WIN32
#include <unistd.h>
#endif

#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif

// phi
#include "paddle/phi/kernels/declarations.h"

// Turns a Paddle variable name into a legal C++ identifier for the generated
// code: '@' cannot appear in an identifier, so every '@' becomes '_'.
static std::string LegalizeVarName(const std::string& var_name) {
  std::string ret = var_name;
  std::replace(ret.begin(), ret.end(), '@', '_');  // replace all '@' with '_'
  return ret;
}

// NOTE(pangyoki): Inplace ops whose inplace input is duplicable.
// For every op listed here, the inplace strategy picks the first VarBase of
// the duplicable input, and that VarBase is shared with the output.
std::set<std::string> inplace_op_duplicable_ins_set{"sum"};

// clang-format off
// Entry of the generated `outs` map for a newly created output VarBase; the
// name is made unique with the generated VarBaseUniqueNameID counter.
const char* OUT_INITIALIZER_TEMPLATE =
    R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
// Entry of the generated `outs` map for a duplicable output; the second %s is
// the name of the size argument holding the number of outputs.
const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableOutput(%s)})";

// Map entry for a single VarBase: the value is wrapped in {}.
const char* INPUT_INITIALIZER_TEMPLATE = R"({"%s", {%s}})";
// Map entry for a list input: it is already a vector, so no extra {} needed.
const char* INPUT_LIST_INITIALIZER_TEMPLATE = R"({"%s", %s})";

// Adds a dispensable single input to `ins` only when it was provided.
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
    if (%s != nullptr) {
      ins["%s"] = {%s};
    }
)";

// Adds a dispensable list input to `ins` only when it is non-empty.
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
    if (%s.size() != 0) {
      ins["%s"] = %s;
    }
)";

// Assigns a caller-provided dispensable single output into `outs`.
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
    outs["%s"] = {%s};
)";

// Assigns a caller-provided dispensable list output into `outs`.
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
    outs["%s"] = %s;
)";
// Name of the generated size argument of a duplicable output, e.g. "OutNum".
const char* ARG_OUT_NUM = R"(%sNum)";
const char* ARG_OUT_NUM_TYPE = R"(size_t )";

// Type spellings for inputs and outputs of the generated function.
const char* IN_VAR_TYPE = R"(py::handle)";
const char* IN_VAR_LIST_TYPE = R"(py::handle)";

const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";

// Statements reading positional argument %d of the Python `args` tuple into a
// local named by the first %s; the last %s is "true"/"false" (dispensable).
const char* CAST_VAR_TEMPLATE = R"(
    auto %s = GetVarBaseFromArgs(op_type, "%s", args, %d, %s);)";

const char* CAST_VAR_LIST_TEMPLATE = R"(
    auto %s = GetVarBaseListFromArgs(op_type, "%s", args, %d, %s);)";

const char* CAST_SIZE_T_TEMPLATE = R"(
    auto %s = GetUnsignedLongFromArgs(op_type, "%s", args, %d, %s);)";

const char* ARG_TEMPLATE = R"(const %s& %s)";

const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])";

const char* FUNCTION_ARGS = R"(%s, const py::args& args)";
const char* FUNCTION_ARGS_NO_INPUT = R"(const py::args& args)";

// Emitted for ops in view_op_map: when both the view input and output are
// present, calls HandleViewBetweenInputAndOutput on their first VarBase.
const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT = R"(
    if (ins.count("%s") && outs.count("%s")) {
      HandleViewBetweenInputAndOutput(ins["%s"][0], outs["%s"][0]);
    })";

// Suffix selecting the first VarBase of a duplicable inplace input.
const char* INPLACE_DUPLICABLE_INPUT = R"([0])";

const char* INPLACE_LEAF_ERROR_MESSAGE = R"(Leaf Var (%s) that doesn't stop gradient can't use inplace strategy.)";

// Rejects inplace on a leaf var that does not stop gradient, then bumps the
// var's inplace version.
const char* INPLACE_STRATEGY_TEMPLATE =
R"(
    PADDLE_ENFORCE_EQ(
      %s->IsLeaf() && !%s->OverridedStopGradient(), false,
      platform::errors::InvalidArgument("%s", %s->Name()));
    %s->BumpInplaceVersion();
    VLOG(3) << "Var(" << %s->Name() << ") uses Inplace Strategy.";
)";

// One {input, output} pair of the inplace map passed to TraceOp.
const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";

// Skeleton of one generated binding function. Placeholders, in order:
// function name, op type, op type (profiler event), argument casts,
// number of positional args (%d), inplace checks, `outs` initializer,
// `ins` initializer, optional ins/outs + view handling, inplace map entries,
// return statement. The GIL is released around tracing and restored on both
// the normal and the exception path.
const char* OP_FUNCTION_TEMPLATE =
R"(
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{
  PyThreadState *tstate = nullptr;
  try
  {
    std::string op_type = "%s";
    platform::RecordEvent op_type_record_event("%s pybind_imperative_func");
    %s
    framework::AttributeMap attrs;
    ConstructAttrMapFromPyArgs(op_type, args, %d, PyTuple_GET_SIZE(args) , attrs);
    tstate = PyEval_SaveThread();
    %s
    imperative::NameVarBaseMap outs = %s;
    imperative::NameVarBaseMap ins = %s;
    %s
    imperative::GetCurrentTracer()->TraceOp(op_type, ins, outs, attrs, {%s});
    PyEval_RestoreThread(tstate);
    tstate = nullptr;
    %s
  }
  catch(...) {
    if (tstate) {
      PyEval_RestoreThread(tstate);
    }
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }
})";

// One entry of the generated PyMethodDef table: Python name, C function,
// calling convention flags, and docstring.
const char* PYBIND_ITEM_TEMPLATE = R"(  {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)";

// clang-format on

static inline bool FindInsMap(const std::string& op_type,
                              const std::string& in_name) {
164 165 166
  return op_ins_map[op_type].count(in_name);
}

L
Leo Chen 已提交
167 168 169 170 171 172 173 174
// Returns true if op_outs_map lists `out_name` for `op_type`; the generator
// keeps a dispensable output only in that case. Uses find() so querying an
// op that is absent from the map does not insert an empty entry into it.
static inline bool FindOutsMap(const std::string& op_type,
                               const std::string& out_name) {
  auto it = op_outs_map.find(op_type);
  return it != op_outs_map.end() && it->second.count(out_name) > 0;
}

// Returns true if op_passing_outs_map lists `out_name` for `op_type`, i.e.
// the generated function reads that output VarBase from its Python args
// instead of creating it. Uses find() so the lookup does not insert an empty
// entry for ops absent from the map.
static inline bool FindPassingOutsMap(const std::string& op_type,
                                      const std::string& out_name) {
  auto it = op_passing_outs_map.find(op_type);
  return it != op_passing_outs_map.end() && it->second.count(out_name) > 0;
}

// Returns true if `op_type` is an inplace op whose inplace input is
// duplicable (listed in inplace_op_duplicable_ins_set).
static inline bool FindDuplicableInputInplaceOpSet(const std::string& op_type) {
  return inplace_op_duplicable_ins_set.count(op_type);
}

// Returns true if `op_type` has an entry in view_op_map (an input/output
// name pair handled by HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT).
static inline bool FindViewOpMap(const std::string& op_type) {
  return view_op_map.count(op_type);
}

// Returns `name` with a trailing underscore appended.
static inline std::string TempName(const std::string& name) {
  return name + '_';
}

// Removes one trailing ',' from *str, if present. The initializer and return
// lists below are built as sequences of "item," and closed with this helper.
static void PopTrailingComma(std::string* str) {
  if (!str->empty() && str->back() == ',') {
    str->pop_back();
  }
}

// Generates the C++ source of one dygraph binding function for `op_proto`
// by filling OP_FUNCTION_TEMPLATE.
//
// The generated function reads, from the Python `args` tuple and in order:
// every kept input, then for each output in proto order either the
// caller-provided ("passing") VarBase or, for a duplicable output, its size
// argument. It then builds the `ins`/`outs` maps, traces the op, and returns
// the outputs through MakeReturnPyObject (or None when there are none).
//
// Args:
//   op_proto: proto of the operator to wrap.
//   func_name: name given to the generated C function.
//   use_inplace_strategy: when true, outputs found in `inplace_map` reuse the
//     VarBase of their mapped input instead of getting a new one.
//   inplace_map: output name -> input name for inplace outputs.
// Returns:
//   The generated function source.
std::string GenerateOpFunctionsBody(
    const paddle::framework::proto::OpProto* op_proto, std::string func_name,
    bool use_inplace_strategy = false,
    std::map<std::string, std::string> inplace_map = {}) {
  auto& op_type = op_proto->type();
  // Index of the next positional Python argument to read. After both loops it
  // equals the number of positional (non-attribute) arguments.
  int arg_idx = 0;
  std::string ins_cast_str = "";
  std::string ins_initializer = "{";
  std::string ins_initializer_with_null = "";
  std::string view_strategy_str = "";
  std::string inplace_strategy_str = "";

  for (auto& input : op_proto->inputs()) {
    auto& in_name = input.name();
    // skip those dispensable inputs, like ResidualData in conv2d
    if (input.dispensable() && !FindInsMap(op_type, in_name)) {
      continue;
    }
    const std::string in_var = LegalizeVarName(in_name);
    const auto in_cast_type =
        input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
    auto dispensable = input.dispensable() ? "true" : "false";
    ins_cast_str += paddle::string::Sprintf(in_cast_type, in_var, in_name,
                                            arg_idx++, dispensable);

    if (input.dispensable()) {
      // Only put the input into `ins` when the caller actually provided it.
      const auto in_template = input.duplicable()
                                   ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
                                   : INPUT_INITIALIZER_TEMPLATE_WITH_NULL;
      ins_initializer_with_null +=
          paddle::string::Sprintf(in_template, in_var, in_name, in_var);
    } else {
      const auto in_template = input.duplicable()
                                   ? INPUT_LIST_INITIALIZER_TEMPLATE
                                   : INPUT_INITIALIZER_TEMPLATE;
      ins_initializer += paddle::string::Sprintf(in_template, in_name, in_var);
      ins_initializer += ",";
    }
  }
  PopTrailingComma(&ins_initializer);
  ins_initializer += "}";

  // Generate outs initializer
  std::string outs_initializer = "{";
  std::string outs_initializer_with_null = "";
  std::string inplace_mapping_str = "";
  std::string return_str = "";

  int outs_num = 0;
  for (auto& output : op_proto->outputs()) {
    auto& out_name = output.name();

    // skip those dispensable outputs
    if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
      continue;
    }
    const auto return_template =
        output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE;
    const std::string out_var = LegalizeVarName(out_name);
    auto inplace_it = inplace_map.find(out_name);

    if (FindPassingOutsMap(op_type, out_name)) {
      // The output VarBase is passed in by the caller as a positional arg.
      if (output.dispensable()) {
        const auto out_template =
            output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
                                : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL;
        // The value must name the local declared by the cast below, which
        // uses the legalized name (the raw name may contain '@').
        outs_initializer_with_null +=
            paddle::string::Sprintf(out_template, out_name, out_var);
      } else {
        const auto out_template = output.duplicable()
                                      ? INPUT_LIST_INITIALIZER_TEMPLATE
                                      : INPUT_INITIALIZER_TEMPLATE;
        outs_initializer +=
            paddle::string::Sprintf(out_template, out_name, out_var);
        outs_initializer += ",";
      }

      const auto in_cast_type =
          output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
      auto dispensable = output.dispensable() ? "true" : "false";
      ins_cast_str += paddle::string::Sprintf(in_cast_type, out_var, out_name,
                                              arg_idx++, dispensable);
    } else if (use_inplace_strategy && inplace_it != inplace_map.end()) {
      PADDLE_ENFORCE_NE(
          inplace_it->second, "",
          paddle::platform::errors::InvalidArgument(
              "Inplace op %s has no input corresponding to output %s.", op_type,
              out_name));

      // TODO(pangyoki): Inplace op don't have duplicable output in temporary,
      // so don't support duplicable output now.
      std::string inplace_input_name = inplace_it->second;
      inplace_mapping_str += paddle::string::Sprintf(
          INPLACE_MAPPING_TEMPLATE, inplace_input_name, out_name);
      inplace_mapping_str += ",";

      // If inplace op has duplicable input, the first Varbase in input will
      // share Varbase with output.
      if (FindDuplicableInputInplaceOpSet(op_type)) {
        inplace_input_name += INPLACE_DUPLICABLE_INPUT;
      }
      const std::string inplace_var = LegalizeVarName(inplace_input_name);

      // Leaf Var that doesn't stop gradient can't use inplace strategy.
      // Increase inplace_version.
      inplace_strategy_str += paddle::string::Sprintf(
          INPLACE_STRATEGY_TEMPLATE, inplace_var, inplace_var,
          INPLACE_LEAF_ERROR_MESSAGE, inplace_var, inplace_var, inplace_var);
      outs_initializer += paddle::string::Sprintf(INPUT_INITIALIZER_TEMPLATE,
                                                  out_name, inplace_var);
      outs_initializer += ",";
    } else {
      // There are few Operators that have duplicable output, like `Out` in
      // split op. We need to specify the number of variables for the
      // duplicable output, as the argument OutNum;
      if (output.duplicable()) {
        auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_var);
        outs_initializer += paddle::string::Sprintf(
            OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);

        auto dispensable = output.dispensable() ? "true" : "false";
        ins_cast_str +=
            paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str,
                                    out_num_str, arg_idx++, dispensable);
      } else {
        outs_initializer +=
            paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
      }
      outs_initializer += ",";
    }

    return_str += paddle::string::Sprintf(return_template, out_name);
    return_str += ",";
    outs_num += 1;
  }
  PopTrailingComma(&outs_initializer);
  outs_initializer += "}";
  // Trimmed independently of outs_initializer: a dispensable passing output
  // is appended to return_str but not to outs_initializer, so the two lists
  // do not always end with a comma at the same time.
  PopTrailingComma(&return_str);
  PopTrailingComma(&inplace_mapping_str);

  if (!use_inplace_strategy && FindViewOpMap(op_type)) {
    const auto& view_io = view_op_map.at(op_type);
    const std::string& view_input_name = view_io.first;
    const std::string& view_output_name = view_io.second;
    view_strategy_str += paddle::string::Sprintf(
        HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT, view_input_name, view_output_name,
        view_input_name, view_output_name);
  }

  if (outs_num == 0) {
    return_str = "Py_INCREF(Py_None);\n    return Py_None;";
  } else if (outs_num == 1) {
    return_str = "return MakeReturnPyObject(" + return_str + ");";
  } else {
    return_str = "return MakeReturnPyObject(" +
                 paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str) +
                 ");";
  }

  // generate op function body
  return paddle::string::Sprintf(
      OP_FUNCTION_TEMPLATE, func_name, op_type, op_type, ins_cast_str, arg_idx,
      inplace_strategy_str, outs_initializer, ins_initializer,
      ins_initializer_with_null + outs_initializer_with_null +
          view_strategy_str,
      inplace_mapping_str, return_str);
}

static std::tuple<std::vector<std::string>, std::vector<std::string>>
402
GenerateOpFunctions() {
403 404
  auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();

405
  std::vector<std::string> op_function_list, bind_function_list;
406 407
  auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();

408 409 410 411 412 413 414
  for (auto& pair : op_info_map) {
    auto& op_info = pair.second;
    auto op_proto = op_info.proto_;
    if (op_proto == nullptr) {
      continue;
    }
    auto& op_type = op_proto->type();
415
    // Skip operator which is not inherit form OperatorWithKernel, like while,
416
    // since only OperatorWithKernel can run in dygraph mode.
417
    // if the phi lib contains op kernel, we still generate ops method
418
    if (!all_kernels.count(op_type) &&
419
        !phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type)) {
420 421
      continue;
    }
422

423 424 425 426 427 428 429 430 431 432 433 434 435 436
    // NOTE(pangyoki): Inplace Strategy.
    // In this case, output will reuse input varbase.
    // Dygraph mode needs to be aligned with the in-place strategy in static
    // mode, and the mapping relationships between output and input that have
    // been defined in static mode should be used in dygraph mode.
    // Find which ops need to use Inplace strategy in static mode, and get the
    // mapping relationship between Inplace output and input.
    auto& infer_inplace =
        paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
    std::map<std::string, std::string> inplace_map;
    if (infer_inplace) {
      auto in_to_outs = infer_inplace(true);
      for (auto& inplace_pair : in_to_outs) {
        inplace_map[inplace_pair.second] = inplace_pair.first;
437 438
      }
    }
439

440
    std::string func_name = "imperative_" + op_type;
441
    std::string op_function_str = GenerateOpFunctionsBody(op_proto, func_name);
442 443

    // generate pybind item
444
    auto bind_function_str = paddle::string::Sprintf(
445
        PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);
446 447 448

    op_function_list.emplace_back(std::move(op_function_str));
    bind_function_list.emplace_back(std::move(bind_function_str));
449 450 451 452 453 454 455 456 457 458 459

    if (infer_inplace) {
      // Reuse Varbase Inplace OP: op_type_.
      // The inplace OP needs a new implementation method.
      std::string inplace_op_type = op_type + "_";
      std::string inplace_func_name = "imperative_" + inplace_op_type;
      std::string inplace_op_function_str = GenerateOpFunctionsBody(
          op_proto, inplace_func_name, true, inplace_map);

      // generate pybind item
      auto inplace_bind_function_str =
460 461
          paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, inplace_op_type,
                                  inplace_func_name, inplace_op_type);
462 463 464 465

      op_function_list.emplace_back(std::move(inplace_op_function_str));
      bind_function_list.emplace_back(std::move(inplace_bind_function_str));
    }
466
  }
467
  return std::make_tuple(op_function_list, bind_function_list);
468 469 470 471 472 473 474 475
}

int main(int argc, char* argv[]) {
  if (argc != 2) {
    std::cerr << "argc must be 2" << std::endl;
    return -1;
  }

476
#ifdef PADDLE_WITH_ASCEND_CL
477 478 479 480
  auto ascend_ptr = paddle::framework::AscendInstance::GetInstance();
  ascend_ptr->InitGEForUT();
#endif

481
  std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
482
                                   "\"paddle/fluid/platform/profiler.h\"",
483 484
                                   "\"pybind11/detail/common.h\"",
                                   "<Python.h>"};
485 486 487 488 489 490 491 492 493

  std::ofstream out(argv[1], std::ios::out);

  out << "#pragma once\n\n";

  for (auto& header : headers) {
    out << "#include  " + header + "\n";
  }

494 495 496
  out << "\n\n";

  auto op_funcs = GenerateOpFunctions();
497

498
  out << "namespace paddle {\n"
499 500
      << "namespace pybind {\n\n";
  out << "std::atomic<int> VarBaseUniqueNameID{0};\n";
501 502
  out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
  out << "\n\n";
503

504 505 506 507
  out << "static PyMethodDef ExtestMethods[] = {\n"
      << paddle::string::join_strings(std::get<1>(op_funcs), '\n')
      << "\n  {nullptr,nullptr,0,nullptr}"
      << "};\n\n";
508

509 510 511 512 513 514 515 516
  out << "inline void BindOpFunctions(pybind11::module *module) {\n"
      << "  auto m = module->def_submodule(\"ops\");\n"
      << "  if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
      << "    PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
         "core.ops failed!\"));\n"
      << "  }\n\n"
      << "  InitOpsAttrTypeMap();"
      << "}\n\n"
517 518 519 520
      << "} // namespace pybind\n"
      << "} // namespace paddle\n";

  out.close();
521

522
#ifdef PADDLE_WITH_ASCEND_CL
523 524
  ge::GEFinalize();
#endif
525

526 527
  return 0;
}