未验证 提交 d980d251 编写于 作者: L Leo Chen 提交者: GitHub

specify outs, test=develop (#24537)

上级 16817c70
...@@ -24,13 +24,48 @@ ...@@ -24,13 +24,48 @@
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are
// determined by the OP`s proto automatically, i.e., all the inputs registered
// in OpMaker.
// However, some OPs have dispensable inputs, which means the input can
// be none for some conditions. It is discovered that most dispensable inputs
// is not used in imperative mode, so we drop those inputs when generating OP
// functions. While, for very few OPs, the dispensable inputs are used, we
// need to manually specify them in this map.
std::map<std::string, std::set<std::string>> op_ins_map = { std::map<std::string, std::set<std::string>> op_ins_map = {
{"layer_norm", {"X", "Scale", "Bias"}}, {"layer_norm", {"X", "Scale", "Bias"}},
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}}, {"label_smooth", {"X", "PriorDist"}},
{"assign", {"X"}}, {"assign", {"X"}},
{"fake_quantize_dequantize_moving_average_abs_max",
{"X", "InScale", "InAccum", "InState"}},
}; };
std::map<std::string, std::set<std::string>> op_passing_out_map = {
// NOTE(zhiqiu): Like op_ins_map.
// Commonly, the outputs in auto-generated OP function are determined by the
// OP`s proto automatically, i.e., all the outputs registered in OpMaker.
// However, some OPs have dispensable outputs, which means the output can
// be none for some conditions. It is discovered that most dispensable outputs
// is not used in imperative mode, so we drop those outputs when generating OP
// functions. While, for very few OPs, the dispensable outputs are used, we
// need to manually specify them in this map.
std::map<std::string, std::set<std::string>> op_outs_map = {
{"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
// generated in C++ automatically.
// However, some OPs need to pass the outputs from Python instead of generating
// them in C++. There are mainly 2 reasons for that,
// (1) Optimizer OPs need to update the input param in-place, like sgd.
// So they need to pass the output which is same as input param.
// (2) Very few python APIs has out in their arguments, like fill_constant.
// So they need to pass the python output to C++.
// Actually, this is not a good design, since it may break the SSA graph,
// especially in declarative mode.
// For those OPs, we need to manually specify the outs need to pass in this map.
std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"sgd", {"ParamOut"}}, {"sgd", {"ParamOut"}},
{"adam", {"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
...@@ -38,7 +73,10 @@ std::map<std::string, std::set<std::string>> op_passing_out_map = { ...@@ -38,7 +73,10 @@ std::map<std::string, std::set<std::string>> op_passing_out_map = {
{"batch_norm", {"MeanOut", "VarianceOut"}}, {"batch_norm", {"MeanOut", "VarianceOut"}},
{"accuracy", {"Correct", "Total"}}, {"accuracy", {"Correct", "Total"}},
{"fill_constant", {"Out"}}, {"fill_constant", {"Out"}},
{"matmul", {"Out"}}}; {"matmul", {"Out"}},
{"fake_quantize_dequantize_moving_average_abs_max",
{"OutScale", "OutAccum", "OutState"}},
};
// clang-format off // clang-format off
const char* OUT_INITIALIZER_TEMPLATE = const char* OUT_INITIALIZER_TEMPLATE =
...@@ -47,17 +85,30 @@ const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableO ...@@ -47,17 +85,30 @@ const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableO
const char* INPUT_INITIALIZER_TEMPLATE = R"({"%s", {%s}})"; const char* INPUT_INITIALIZER_TEMPLATE = R"({"%s", {%s}})";
const char* INPUT_LIST_INITIALIZER_TEMPLATE = R"({"%s", %s})"; const char* INPUT_LIST_INITIALIZER_TEMPLATE = R"({"%s", %s})";
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
if (%s != nullptr) { const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
ins["%s"] = {%s}; if (%s != nullptr) {
} ins["%s"] = {%s};
}
)"; )";
const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
if (%s != nullptr) { const char* INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
ins["%s"] = %s; if (%s.size() != 0) {
} ins["%s"] = %s;
}
)";
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL = R"(
if (%s != nullptr) {
outs["%s"] = {%s};
}
)"; )";
const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
if (%s.size() != 0) {
outs["%s"] = %s;
}
)";
// if inputs is list, no need {} // if inputs is list, no need {}
const char* ARG_OUT_NUM = R"(%sNum)"; const char* ARG_OUT_NUM = R"(%sNum)";
const char* ARG_OUT_NUM_TYPE = R"(size_t )"; const char* ARG_OUT_NUM_TYPE = R"(size_t )";
...@@ -95,14 +146,19 @@ R"( ...@@ -95,14 +146,19 @@ R"(
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", &%s);)"; const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", &%s);)";
// clang-format on // clang-format on
static inline bool FindInputInSpecialization(const std::string& op_type, static inline bool FindInsMap(const std::string& op_type,
const std::string& in_name) { const std::string& in_name) {
return op_ins_map[op_type].count(in_name); return op_ins_map[op_type].count(in_name);
} }
static inline bool FindOutoutInSpecialization(const std::string& op_type, static inline bool FindOutsMap(const std::string& op_type,
const std::string& out_name) { const std::string& out_name) {
return op_passing_out_map[op_type].count(out_name); return op_outs_map[op_type].count(out_name);
}
static inline bool FindPassingOutsMap(const std::string& op_type,
const std::string& out_name) {
return op_passing_outs_map[op_type].count(out_name);
} }
static std::tuple<std::vector<std::string>, std::vector<std::string>> static std::tuple<std::vector<std::string>, std::vector<std::string>>
...@@ -131,7 +187,7 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -131,7 +187,7 @@ GenerateOpFunctions(const std::string& module_name) {
for (auto& input : op_proto->inputs()) { for (auto& input : op_proto->inputs()) {
auto& in_name = input.name(); auto& in_name = input.name();
// skip those dispensable inputs, like ResidualData in conv2d // skip those dispensable inputs, like ResidualData in conv2d
if (input.dispensable() && !FindInputInSpecialization(op_type, in_name)) { if (input.dispensable() && !FindInsMap(op_type, in_name)) {
continue; continue;
} }
const auto in_type = input.duplicable() ? VAR_LIST_TYPE : VAR_TYPE; const auto in_type = input.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
...@@ -165,30 +221,41 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -165,30 +221,41 @@ GenerateOpFunctions(const std::string& module_name) {
// Generate outs initializer // Generate outs initializer
std::string outs_initializer = "{"; std::string outs_initializer = "{";
std::string outs_initializer_with_null = "";
std::string return_type = ""; std::string return_type = "";
std::string return_str = ""; std::string return_str = "";
int outs_num = 0; int outs_num = 0;
for (auto& output : op_proto->outputs()) { for (auto& output : op_proto->outputs()) {
if (output.dispensable()) { auto& out_name = output.name();
// skip those dispensable oututs
if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
continue; continue;
} }
const auto out_type = output.duplicable() ? VAR_LIST_TYPE : VAR_TYPE; const auto out_type = output.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
const auto return_template = const auto return_template =
output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE; output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE;
auto& out_name = output.name(); if (FindPassingOutsMap(op_type, out_name)) {
std::string out_initializer_str;
if (FindOutoutInSpecialization(op_type, out_name)) {
if (input_args != "") { if (input_args != "") {
input_args += ","; input_args += ",";
} }
input_args += out_type; input_args += out_type;
input_args += out_name; input_args += out_name;
const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE if (output.dispensable()) {
: INPUT_INITIALIZER_TEMPLATE; const auto out_template =
out_initializer_str += output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
paddle::string::Sprintf(out_template, out_name, out_name); : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL;
outs_initializer_with_null += paddle::string::Sprintf(
out_template, out_name, out_name, out_name);
} else {
const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE;
outs_initializer +=
paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += ",";
}
} else { } else {
// There are few Operators that have duplicable output, like `Out` in // There are few Operators that have duplicable output, like `Out` in
// split op. We need to specify the number of variables for the // split op. We need to specify the number of variables for the
...@@ -200,12 +267,13 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -200,12 +267,13 @@ GenerateOpFunctions(const std::string& module_name) {
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
input_args += ARG_OUT_NUM_TYPE; input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str; input_args += out_num_str;
out_initializer_str = paddle::string::Sprintf( outs_initializer += paddle::string::Sprintf(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
} else { } else {
out_initializer_str = outs_initializer +=
paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name); paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
} }
outs_initializer += ",";
} }
return_type += out_type; return_type += out_type;
...@@ -213,9 +281,6 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -213,9 +281,6 @@ GenerateOpFunctions(const std::string& module_name) {
return_str += paddle::string::Sprintf(return_template, out_name); return_str += paddle::string::Sprintf(return_template, out_name);
return_str += ","; return_str += ",";
outs_num += 1; outs_num += 1;
outs_initializer += out_initializer_str;
outs_initializer += ",";
} }
if (outs_initializer.back() == ',') { if (outs_initializer.back() == ',') {
outs_initializer.pop_back(); outs_initializer.pop_back();
...@@ -241,7 +306,8 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -241,7 +306,8 @@ GenerateOpFunctions(const std::string& module_name) {
// generate op funtcion body // generate op funtcion body
auto op_function_str = paddle::string::Sprintf( auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, OP_FUNCTION_TEMPLATE, return_type, func_name, function_args,
outs_initializer, ins_initializer, ins_initializer_with_null, op_type, outs_initializer, ins_initializer,
ins_initializer_with_null + outs_initializer_with_null, op_type,
return_str); return_str);
// generate pybind item // generate pybind item
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册