diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 3154ed01867de5d5e8cdb6a1d8d76a713e531ba5..a9075a333a0f9fde8cd32861a1ed3c0ed350b7ba 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -24,13 +24,48 @@ #include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/string/string_helper.h" +// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are +// determined by the OP`s proto automatically, i.e., all the inputs registered +// in OpMaker. +// However, some OPs have dispensable inputs, which means the input can +// be none for some conditions. It is discovered that most dispensable inputs +// is not used in imperative mode, so we drop those inputs when generating OP +// functions. While, for very few OPs, the dispensable inputs are used, we +// need to manually specify them in this map. std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, {"assign", {"X"}}, + {"fake_quantize_dequantize_moving_average_abs_max", + {"X", "InScale", "InAccum", "InState"}}, }; -std::map> op_passing_out_map = { + +// NOTE(zhiqiu): Like op_ins_map. +// Commonly, the outputs in auto-generated OP function are determined by the +// OP`s proto automatically, i.e., all the outputs registered in OpMaker. +// However, some OPs have dispensable outputs, which means the output can +// be none for some conditions. It is discovered that most dispensable outputs +// is not used in imperative mode, so we drop those outputs when generating OP +// functions. While, for very few OPs, the dispensable outputs are used, we +// need to manually specify them in this map. +std::map> op_outs_map = { + {"fake_quantize_dequantize_moving_average_abs_max", + {"Out", "OutScale", "OutAccum", "OutState"}}, +}; + +// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are +// generated in C++ automatically. +// However, some OPs need to pass the outputs from Python instead of generating +// them in C++. There are mainly 2 reasons for that, +// (1) Optimizer OPs need to update the input param in-place, like sgd. +// So they need to pass the output which is same as input param. +// (2) Very few python APIs has out in their arguments, like fill_constant. +// So they need to pass the python output to C++. +// Actually, this is not a good design, since it may break the SSA graph, +// especially in declarative mode. +// For those OPs, we need to manually specify the outs need to pass in this map. +std::map> op_passing_outs_map = { {"sgd", {"ParamOut"}}, {"adam", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, @@ -38,7 +73,10 @@ std::map> op_passing_out_map = { {"batch_norm", {"MeanOut", "VarianceOut"}}, {"accuracy", {"Correct", "Total"}}, {"fill_constant", {"Out"}}, - {"matmul", {"Out"}}}; + {"matmul", {"Out"}}, + {"fake_quantize_dequantize_moving_average_abs_max", + {"OutScale", "OutAccum", "OutState"}}, +}; // clang-format off const char* OUT_INITIALIZER_TEMPLATE = @@ -47,17 +85,30 @@ const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableO const char* INPUT_INITIALIZER_TEMPLATE = R"({"%s", {%s}})"; const char* INPUT_LIST_INITIALIZER_TEMPLATE = R"({"%s", %s})"; -const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"( - if (%s != nullptr) { - ins["%s"] = {%s}; - } + +const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"( + if (%s != nullptr) { + ins["%s"] = {%s}; + } )"; -const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"( - if (%s != nullptr) { - ins["%s"] = %s; - } + +const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"( + if (%s.size() != 0) { + ins["%s"] = %s; + } +)"; + +const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"( + if (%s != nullptr) { + outs["%s"] = {%s}; + } )"; +const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"( + if (%s.size() != 0) { + outs["%s"] = %s; + } +)"; // if inputs is list, no need {} const char* ARG_OUT_NUM = R"(%sNum)"; const char* ARG_OUT_NUM_TYPE = R"(size_t )"; @@ -95,14 +146,19 @@ R"( const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", &%s);)"; // clang-format on -static inline bool FindInputInSpecialization(const std::string& op_type, - const std::string& in_name) { +static inline bool FindInsMap(const std::string& op_type, + const std::string& in_name) { return op_ins_map[op_type].count(in_name); } -static inline bool FindOutoutInSpecialization(const std::string& op_type, - const std::string& out_name) { - return op_passing_out_map[op_type].count(out_name); +static inline bool FindOutsMap(const std::string& op_type, + const std::string& out_name) { + return op_outs_map[op_type].count(out_name); +} + +static inline bool FindPassingOutsMap(const std::string& op_type, + const std::string& out_name) { + return op_passing_outs_map[op_type].count(out_name); } static std::tuple, std::vector> @@ -131,7 +187,7 @@ GenerateOpFunctions(const std::string& module_name) { for (auto& input : op_proto->inputs()) { auto& in_name = input.name(); // skip those dispensable inputs, like ResidualData in conv2d - if (input.dispensable() && !FindInputInSpecialization(op_type, in_name)) { + if (input.dispensable() && !FindInsMap(op_type, in_name)) { continue; } const auto in_type = input.duplicable() ? VAR_LIST_TYPE : VAR_TYPE; @@ -165,30 +221,41 @@ GenerateOpFunctions(const std::string& module_name) { // Generate outs initializer std::string outs_initializer = "{"; + std::string outs_initializer_with_null = ""; std::string return_type = ""; std::string return_str = ""; int outs_num = 0; for (auto& output : op_proto->outputs()) { - if (output.dispensable()) { + auto& out_name = output.name(); + // skip those dispensable oututs + if (output.dispensable() && !FindOutsMap(op_type, out_name)) { continue; } const auto out_type = output.duplicable() ? VAR_LIST_TYPE : VAR_TYPE; const auto return_template = output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE; - auto& out_name = output.name(); - std::string out_initializer_str; - if (FindOutoutInSpecialization(op_type, out_name)) { + if (FindPassingOutsMap(op_type, out_name)) { if (input_args != "") { input_args += ","; } input_args += out_type; input_args += out_name; - const auto out_template = output.duplicable() - ? INPUT_LIST_INITIALIZER_TEMPLATE - : INPUT_INITIALIZER_TEMPLATE; - out_initializer_str += - paddle::string::Sprintf(out_template, out_name, out_name); + + if (output.dispensable()) { + const auto out_template = + output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST + : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL; + outs_initializer_with_null += paddle::string::Sprintf( + out_template, out_name, out_name, out_name); + } else { + const auto out_template = output.duplicable() + ? INPUT_LIST_INITIALIZER_TEMPLATE + : INPUT_INITIALIZER_TEMPLATE; + outs_initializer += + paddle::string::Sprintf(out_template, out_name, out_name); + outs_initializer += ","; + } } else { // There are few Operators that have duplicable output, like `Out` in // split op. We need to specify the number of variables for the @@ -200,12 +267,13 @@ GenerateOpFunctions(const std::string& module_name) { auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); input_args += ARG_OUT_NUM_TYPE; input_args += out_num_str; - out_initializer_str = paddle::string::Sprintf( + outs_initializer += paddle::string::Sprintf( OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); } else { - out_initializer_str = + outs_initializer += paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name); } + outs_initializer += ","; } return_type += out_type; @@ -213,9 +281,6 @@ GenerateOpFunctions(const std::string& module_name) { return_str += paddle::string::Sprintf(return_template, out_name); return_str += ","; outs_num += 1; - - outs_initializer += out_initializer_str; - outs_initializer += ","; } if (outs_initializer.back() == ',') { outs_initializer.pop_back(); @@ -241,7 +306,8 @@ GenerateOpFunctions(const std::string& module_name) { // generate op funtcion body auto op_function_str = paddle::string::Sprintf( OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, - outs_initializer, ins_initializer, ins_initializer_with_null, op_type, + outs_initializer, ins_initializer, + ins_initializer_with_null + outs_initializer_with_null, op_type, return_str); // generate pybind item