op_function_generator.cc 18.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <algorithm>
16 17 18
#include <fstream>
#include <iostream>
#include <string>
19
#include <unistd.h>
20 21 22 23 24 25 26

#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
27 28 29
#ifdef PADDLE_WITH_ASCEND
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif
30

L
Leo Chen 已提交
31 32 33 34 35 36 37 38
// NOTE(zhiqiu): Commonly, the inputs of an auto-generated OP function are
// determined automatically by the OP's proto, i.e., all the inputs registered
// in its OpMaker.
// However, some OPs have dispensable inputs, which may be absent under some
// conditions. Most dispensable inputs are not used in imperative mode, so
// they are dropped when generating OP functions. For the very few OPs whose
// dispensable inputs ARE used, the inputs to keep must be listed manually in
// this map (op type -> input names). Only dispensable inputs consult this
// map; non-dispensable inputs are always kept.
std::map<std::string, std::set<std::string>> op_ins_map = {
    {"layer_norm", {"X", "Scale", "Bias"}},
    {"instance_norm", {"X", "Scale", "Bias"}},
    {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
    {"label_smooth", {"X", "PriorDist"}},
    {"assign", {"X"}},
    {"reshape2", {"X", "Shape"}},
    {"expand", {"X", "ExpandTimes"}},
    {"slice", {"Input", "StartsTensor", "EndsTensor"}},
    {"fake_quantize_dequantize_moving_average_abs_max",
     {"X", "InScale", "InAccum", "InState"}},
    {"nll_loss", {"X", "Label", "Weight"}},
    {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}},
    {"gather", {"X", "Index", "Axis"}},
    {"roi_pool", {"X", "ROIs", "RoisNum"}},
    {"roi_align", {"X", "ROIs", "RoisNum"}},
    {"collect_fpn_proposals",
     {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}},
    {"distribute_fpn_proposals", {"FpnRois", "RoisNum"}},
    {"warpctc", {"Logits", "Label", "LogitsLength", "LabelLength"}},
    {"hierarchical_sigmoid",
     {"X", "W", "Label", "PathTable", "PathCode", "Bias"}},
    {"moving_average_abs_max_scale", {"X", "InAccum", "InState"}},
    {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
    {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
    {"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
    {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
};
L
Leo Chen 已提交
67 68 69 70 71 72 73 74 75 76 77 78

// NOTE(zhiqiu): Like op_ins_map.
// Commonly, the outputs of an auto-generated OP function are determined
// automatically by the OP's proto, i.e., all the outputs registered in its
// OpMaker.
// However, some OPs have dispensable outputs, which may be absent under some
// conditions. Most dispensable outputs are not used in imperative mode, so
// they are dropped when generating OP functions. For the very few OPs whose
// dispensable outputs ARE used, the outputs to keep must be listed manually
// in this map (op type -> output names).
std::map<std::string, std::set<std::string>> op_outs_map = {
    {"fake_quantize_dequantize_moving_average_abs_max",
     {"Out", "OutScale", "OutAccum", "OutState"}},
    {"batch_norm",
     {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
      "ReserveSpace"}},
    {"sync_batch_norm",
     {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
      "ReserveSpace"}},
    {"unique", {"Out", "Index", "Indices", "Counts"}},
    {"generate_proposals", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
    {"collect_fpn_proposals", {"FpnRois", "RoisNum"}},
    {"matrix_nms", {"Out", "Index", "RoisNum"}},
    {"distribute_fpn_proposals",
     {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
    {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
    {"multiclass_nms3", {"Out", "NmsRoisNum"}},
    {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
    {"momentum", {"ParamOut", "VelocityOut"}},
    {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
};

// NOTE(zhiqiu): Commonly, the outputs of an auto-generated OP function are
// created in C++ automatically.
// However, some OPs need to pass the outputs in from Python instead of
// creating them in C++. There are mainly 2 reasons for that:
// (1) Optimizer OPs need to update the input param in-place, like sgd.
//     So they need to pass an output which is the same as the input param.
// (2) Very few Python APIs have `out` in their arguments, like fill_constant.
//     So they need to pass the Python output to C++.
//     Actually, this is not a good design, since it may break the SSA graph,
//     especially in declarative mode.
// For those OPs, the outputs that must be passed in are listed manually in
// this map (op type -> output names).
std::map<std::string, std::set<std::string>> op_passing_outs_map = {
    {"sgd", {"ParamOut"}},
    {"adam",
     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
    {"average_accumulates",
     {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
      "out_old_num_accumulates", "out_num_updates"}},
    {"momentum", {"ParamOut", "VelocityOut"}},
    {"batch_norm", {"MeanOut", "VarianceOut"}},
    {"sync_batch_norm", {"MeanOut", "VarianceOut"}},
    {"accuracy", {"Correct", "Total"}},
    {"fill_constant", {"Out"}},
    {"matmul", {"Out"}},
    {"c_broadcast", {"Out"}},
    {"c_allreduce_sum", {"Out"}},
    {"c_allreduce_max", {"Out"}},
    {"c_allreduce_min", {"Out"}},
    {"c_allreduce_prod", {"Out"}},
    {"c_reduce_sum", {"Out"}},
    {"c_reduce_max", {"Out"}},
    {"c_reduce_min", {"Out"}},
    {"c_reduce_prod", {"Out"}},
    {"c_reduce", {"Out"}},
    {"c_allgather", {"Out"}},
    {"c_scatter", {"Out"}},
    {"barrier", {"Out"}},
    {"fake_quantize_dequantize_moving_average_abs_max",
     {"Out", "OutScale", "OutAccum", "OutState"}},
    {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
    {"fake_channel_wise_quantize_dequantize_abs_max", {"Out", "OutScale"}},
    {"check_finite_and_unscale", {"Out", "FoundInfinite"}},
    {"update_loss_scaling",
     {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
    {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
    {"rnn", {"DropoutState"}},
};
145

146 147 148 149 150 151 152 153 154 155 156 157 158
// NOTE(pangyoki): Tensor View Strategy.
// For the ops listed here, a new output varbase is created, and this varbase
// reuses the input varbase's allocation.
// Key: the view op type. Value: an (input name, output name) pair naming the
// input slot whose varbase allocation is shared with the output slot.
std::map<std::string, std::pair<std::string, std::string>> view_op_map = {
    {"squeeze2", {"X", "Out"}},  // "X" -> "Out"
    {"unsqueeze2", {"X", "Out"}},
    {"reshape2", {"X", "Out"}},
    {"flatten_contiguous_range", {"X", "Out"}},
};

159
// clang-format off
// Code templates used to emit the imperative op functions. Each one is
// filled in with paddle::string::Sprintf. These are raw string literals, so
// every character (including newlines and whitespace) is copied verbatim
// into the generated header.

// Output initializer that creates a fresh VarBase. %s: output name.
const char* OUT_INITIALIZER_TEMPLATE =
    R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase(tracer->GenerateUniqueName()))}})";
// Output initializer for a duplicable output.
// %s: output name, %s: name of the argument holding the number of variables.
const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableOutput(%s)})";

// Input initializer. %s: slot name, %s: variable expression.
const char* INPUT_INITIALIZER_TEMPLATE = R"({"%s", {%s}})";
// If the input is a list it is already a vector, so no extra {} is needed.
const char* INPUT_LIST_INITIALIZER_TEMPLATE = R"({"%s", %s})";

// Adds a dispensable input only when it was provided.
// %s: variable name, %s: slot name, %s: variable name.
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(	
    if (%s != nullptr) {	
      ins["%s"] = {%s};	
    }	
)";

// List variant of the template above; an empty list means "not provided".
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(	
    if (%s.size() != 0) {
      ins["%s"] = %s;	
    }	
)";

// Assigns a passed-in dispensable output. %s: slot name, %s: variable name.
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
    outs["%s"] = {%s};
)";

// List variant of the template above (already a vector, so no {}).
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
    outs["%s"] = %s;
)";

// Name and type of the extra argument giving the variable count of a
// duplicable output. %s: output name.
const char* ARG_OUT_NUM = R"(%sNum)";
const char* ARG_OUT_NUM_TYPE = R"(size_t )";

const char* IN_VAR_TYPE = R"(py::handle)";
const char* IN_VAR_LIST_TYPE = R"(py::handle)";

const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";

// Casts a py::handle argument to a VarBase (or a list of them).
// %s: local variable, %s: op type, %s: input name, %d: argument index,
// %s: handle argument name, %s: "true"/"false" dispensable flag.
const char* CAST_VAR_TEMPLATE = R"(
  auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)";

const char* CAST_VAR_LIST_TEMPLATE = R"(
  auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)";

// One function parameter. %s: type, %s: name.
const char* ARG_TEMPLATE = R"(const %s& %s)";

const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TYPE = R"(%s)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])";

const char* FUNCTION_ARGS = R"(%s, const py::args& args)";
const char* FUNCTION_ARGS_NO_INPUT = R"(const py::args& args)";

// View strategy: makes the output share the input's allocation.
// %s: input name, %s: output name, %s: input name, %s: output name.
const char* HandleViewBetweenInputAndOutput = R"(
    if (ins.count("%s") && outs.count("%s")) {
      HandleViewBetweenInputAndOutput(ins["%s"][0], outs["%s"][0]);
    })";

// Whole op function. Placeholders, in order: return type, function name,
// parameter list, input casts, op type, number of positional input args (%d),
// outs initializer, ins initializer, extra statements (dispensable ins/outs
// and view handling), op type, return expression.
const char* OP_FUNCTION_TEMPLATE =
R"(
%s %s(%s)
{
  %s
  framework::AttributeMap attrs;
  ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
  {
    py::gil_scoped_release release;
    auto tracer = imperative::GetCurrentTracer();
    imperative::NameVarBaseMap outs = %s;
    imperative::NameVarBaseMap ins = %s;
    %s
    tracer->TraceOp("%s", ins, outs, attrs);
    return %s; 
  }   
})";

// pybind registration line. %s: module variable, %s: op type,
// %s: generated function name.
const char* PYBIND_ITEM_TEMPLATE = R"(  %s.def("%s", &%s);)";

// clang-format on
L
Leo Chen 已提交
240 241
static inline bool FindInsMap(const std::string& op_type,
                              const std::string& in_name) {
242 243 244
  return op_ins_map[op_type].count(in_name);
}

L
Leo Chen 已提交
245 246 247 248 249 250 251 252
// Returns true if `out_name` is listed for `op_type` in op_outs_map, i.e. the
// dispensable output must be kept in the generated function.
// Uses find() instead of operator[] so that querying an unlisted op does not
// insert an empty entry into the global map as a side effect.
static inline bool FindOutsMap(const std::string& op_type,
                               const std::string& out_name) {
  auto it = op_outs_map.find(op_type);
  return it != op_outs_map.end() && it->second.count(out_name) > 0;
}

// Returns true if `out_name` of `op_type` is listed in op_passing_outs_map,
// i.e. the output is passed in from Python rather than created in C++.
// Uses find() instead of operator[] so that querying an unlisted op does not
// insert an empty entry into the global map as a side effect.
static inline bool FindPassingOutsMap(const std::string& op_type,
                                      const std::string& out_name) {
  auto it = op_passing_outs_map.find(op_type);
  return it != op_passing_outs_map.end() && it->second.count(out_name) > 0;
}
254

255 256 257 258
// Returns true if `op_type` is a view op registered in view_op_map, i.e. its
// output should share the input's allocation.
static inline bool FindViewOpMap(const std::string& op_type) {
  const auto found = view_op_map.find(op_type);
  return found != view_op_map.end();
}

259 260 261 262
// Returns the name used for the raw py::handle parameter of input `name` in
// the generated function: the input name with a trailing underscore.
static inline std::string TempName(const std::string& name) {
  std::string suffixed(name);
  suffixed.push_back('_');
  return suffixed;
}

263 264
static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) {
265 266
  auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();

267
  std::vector<std::string> op_function_list, bind_function_list;
268 269
  auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();

270 271 272 273 274 275 276
  for (auto& pair : op_info_map) {
    auto& op_info = pair.second;
    auto op_proto = op_info.proto_;
    if (op_proto == nullptr) {
      continue;
    }
    auto& op_type = op_proto->type();
277 278 279 280 281 282 283 284 285
    // Skip ooerator which is not inherit form OperatorWithKernel, like while,
    // since only OperatorWithKernel can run in dygraph mode.
    if (!all_kernels.count(op_type)) {
      continue;
    }
    std::string input_args = "";
    std::string ins_initializer = "{";
    std::string ins_initializer_with_null = "";
    std::string py_arg = "";
286
    int arg_idx = 0;
287
    int input_args_num = 0;
288
    std::string ins_cast_str = "";
289
    std::string view_strategy_str = "";
290 291 292
    for (auto& input : op_proto->inputs()) {
      auto& in_name = input.name();
      // skip those dispensable inputs, like ResidualData in conv2d
L
Leo Chen 已提交
293
      if (input.dispensable() && !FindInsMap(op_type, in_name)) {
294 295
        continue;
      }
296 297 298
      const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
      auto input_arg =
          paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
299 300
      input_args += input_arg;
      input_args += ",";
301
      input_args_num++;
302 303
      const auto in_cast_type =
          input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
304
      auto dispensable = input.dispensable() ? "true" : "false";
305 306
      ins_cast_str +=
          paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
307
                                  arg_idx++, TempName(in_name), dispensable);
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331

      if (input.dispensable()) {
        const auto in_template = input.duplicable()
                                     ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
                                     : INPUT_INITIALIZER_TEMPLATE_WITH_NULL;
        ins_initializer_with_null +=
            paddle::string::Sprintf(in_template, in_name, in_name, in_name);
      } else {
        const auto in_template = input.duplicable()
                                     ? INPUT_LIST_INITIALIZER_TEMPLATE
                                     : INPUT_INITIALIZER_TEMPLATE;
        ins_initializer +=
            paddle::string::Sprintf(in_template, in_name, in_name);
        ins_initializer += ",";
      }
    }
    if (ins_initializer.back() == ',') {
      ins_initializer.pop_back();
    }
    ins_initializer += "}";

    if (input_args.back() == ',') {
      input_args.pop_back();
    }
332 333 334

    // Generate outs initializer
    std::string outs_initializer = "{";
L
Leo Chen 已提交
335
    std::string outs_initializer_with_null = "";
336 337
    std::string return_type = "";
    std::string return_str = "";
338

339
    int outs_num = 0;
340
    for (auto& output : op_proto->outputs()) {
L
Leo Chen 已提交
341 342 343
      auto& out_name = output.name();
      // skip those dispensable oututs
      if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
344 345
        continue;
      }
346 347
      const auto out_type =
          output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE;
348 349
      const auto return_template =
          output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE;
L
Leo Chen 已提交
350
      if (FindPassingOutsMap(op_type, out_name)) {
351 352 353 354 355
        if (input_args != "") {
          input_args += ",";
        }
        input_args += out_type;
        input_args += out_name;
356
        input_args_num++;
L
Leo Chen 已提交
357 358 359 360 361

        if (output.dispensable()) {
          const auto out_template =
              output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
                                  : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL;
362 363
          outs_initializer_with_null +=
              paddle::string::Sprintf(out_template, out_name, out_name);
L
Leo Chen 已提交
364 365 366 367 368 369 370 371
        } else {
          const auto out_template = output.duplicable()
                                        ? INPUT_LIST_INITIALIZER_TEMPLATE
                                        : INPUT_INITIALIZER_TEMPLATE;
          outs_initializer +=
              paddle::string::Sprintf(out_template, out_name, out_name);
          outs_initializer += ",";
        }
372 373 374 375 376 377 378 379 380 381 382
      } else {
        // There are few Operators that have duplicable output, like `Out` in
        // split op. We need to specify the number of variables for the
        // duplicable output, as the argument OutNum;
        if (output.duplicable()) {
          if (input_args != "") {
            input_args += ",";
          }
          auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
          input_args += ARG_OUT_NUM_TYPE;
          input_args += out_num_str;
383
          input_args_num++;
L
Leo Chen 已提交
384
          outs_initializer += paddle::string::Sprintf(
385 386
              OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
        } else {
L
Leo Chen 已提交
387
          outs_initializer +=
388 389
              paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
        }
L
Leo Chen 已提交
390
        outs_initializer += ",";
391 392 393 394 395 396 397
      }

      return_type += out_type;
      return_type += ",";
      return_str += paddle::string::Sprintf(return_template, out_name);
      return_str += ",";
      outs_num += 1;
398 399 400
    }
    if (outs_initializer.back() == ',') {
      outs_initializer.pop_back();
401 402
      return_type.pop_back();
      return_str.pop_back();
403 404
    }
    outs_initializer += "}";
405 406 407 408 409 410 411
    if (FindViewOpMap(op_type)) {
      std::string viwe_input_name = view_op_map[op_type].first;
      std::string viwe_output_name = view_op_map[op_type].second;
      view_strategy_str += paddle::string::Sprintf(
          HandleViewBetweenInputAndOutput, viwe_input_name, viwe_output_name,
          viwe_input_name, viwe_output_name);
    }
412 413 414 415 416 417 418 419 420
    if (outs_num == 0) {
      return_type = "void";
    }
    if (outs_num > 1) {
      return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str);
      return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type);
    }
    std::string function_args = "";
    if (input_args == "") {
421
      function_args = FUNCTION_ARGS_NO_INPUT;
422 423 424
    } else {
      function_args = paddle::string::Sprintf(FUNCTION_ARGS, input_args);
    }
425

426
    std::string func_name = "imperative_" + op_type;
427
    // generate op funtcion body
428
    auto op_function_str = paddle::string::Sprintf(
429
        OP_FUNCTION_TEMPLATE, return_type, func_name, function_args,
430
        ins_cast_str, op_type, input_args_num, outs_initializer,
431 432
        ins_initializer, ins_initializer_with_null +
                             outs_initializer_with_null + view_strategy_str,
433
        op_type, return_str);
434 435

    // generate pybind item
436 437 438 439 440
    auto bind_function_str = paddle::string::Sprintf(
        PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name);

    op_function_list.emplace_back(std::move(op_function_str));
    bind_function_list.emplace_back(std::move(bind_function_str));
441
  }
442
  return std::make_tuple(op_function_list, bind_function_list);
443 444 445 446 447 448 449 450
}

int main(int argc, char* argv[]) {
  if (argc != 2) {
    std::cerr << "argc must be 2" << std::endl;
    return -1;
  }

451 452 453 454 455
#ifdef PADDLE_WITH_ASCEND
  auto ascend_ptr = paddle::framework::AscendInstance::GetInstance();
  ascend_ptr->InitGEForUT();
#endif

456 457 458 459 460 461 462 463 464 465
  std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""};

  std::ofstream out(argv[1], std::ios::out);

  out << "#pragma once\n\n";

  for (auto& header : headers) {
    out << "#include  " + header + "\n";
  }

466 467
  auto op_funcs = GenerateOpFunctions("m");

468 469 470
  out << "namespace py = pybind11;"
      << "\n";
  out << "namespace paddle {\n"
471 472 473
      << "namespace pybind {\n";
  out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
  out << "\n\n";
474

475 476
  out << "inline void BindOpFunctions(pybind11::module *module) {\n"
      << "  auto m = module->def_submodule(\"ops\");\n\n";
477

478 479
  out << paddle::string::join_strings(std::get<1>(op_funcs), '\n');
  out << "\n";
480 481 482 483 484
  out << "}\n\n"
      << "} // namespace pybind\n"
      << "} // namespace paddle\n";

  out.close();
485 486 487 488

#ifdef PADDLE_WITH_ASCEND
  ge::GEFinalize();
#endif
489 490
  return 0;
}