未验证 提交 832e58d6 编写于 作者: J Jiabin Yang 提交者: GitHub

Fix stray error (#42509)

* fix @ stray error in dygraph

* fix @ stray error in dygraph
上级 06927016
......@@ -56,6 +56,13 @@ static std::unordered_set<std::string> black_ops_list = {"run_program"};
static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '-', '_'); // replace all '-' to '_'
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
......@@ -1024,7 +1031,8 @@ static std::string GenerateGradNodeCreationContent(
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name();
const std::string& output_autograd_name = "p_autograd_" + output_name;
const std::string& output_autograd_name =
"p_autograd_" + LegalizeVarName(output_name);
// output autograd_meta should be got after running TraceOP.
if (output.duplicable()) {
......@@ -1032,12 +1040,13 @@ static std::string GenerateGradNodeCreationContent(
" std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::autograd_meta(&%s);\n";
get_output_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name);
GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name,
LegalizeVarName(output_name));
} else {
// In inplace op, the case where output is duplicable is not considered.
// Replace output directly with input in inplace op.
if (!inplace_map.empty() && inplace_map.count(output_name)) {
auto inplace_input_name = inplace_map[output_name];
auto inplace_input_name = LegalizeVarName(inplace_map[output_name]);
const std::string& inplace_input_autograd_name =
"p_autograd_" + inplace_input_name;
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
......@@ -1049,9 +1058,9 @@ static std::string GenerateGradNodeCreationContent(
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = "
"egr::EagerUtils::autograd_meta(&%s);\n";
get_output_autograd_meta_str +=
paddle::string::Sprintf(GET_SINGLE_AUTOGRAD_META_TEMPLATE,
output_autograd_name, output_name);
get_output_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, output_autograd_name,
LegalizeVarName(output_name));
}
}
}
......@@ -1061,28 +1070,32 @@ static std::string GenerateGradNodeCreationContent(
// inplace).
for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name();
const std::string& input_autograd_name = "p_autograd_" + input_name;
const std::string& input_autograd_name =
"p_autograd_" + LegalizeVarName(input_name);
if (input.duplicable()) {
const char* GET_MULTI_AUTOGRAD_META_TEMPLATE =
" std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_input_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name,
LegalizeVarName(input_name));
} else if (input.dispensable()) {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_input_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name,
LegalizeVarName(input_name));
} else {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_input_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name,
LegalizeVarName(input_name));
}
}
VLOG(6) << "Generated inputs autograd_meta";
......@@ -1096,7 +1109,7 @@ static std::string GenerateGradNodeCreationContent(
" egr::EagerUtils::CheckInplace(%s, p_autograd_%s, "
"require_any_grad);\n";
for (auto& inplace_pair : inplace_map) {
std::string inplace_name = inplace_pair.second;
std::string inplace_name = LegalizeVarName(inplace_pair.second);
check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE,
inplace_name, inplace_name);
}
......@@ -1159,12 +1172,12 @@ static std::string GenerateGradNodeCreationContent(
if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
auto inplace_input_name = inplace_map[tensor_wrapper_name];
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
inplace_input_name, full_reserved);
SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
LegalizeVarName(inplace_input_name), full_reserved);
} else {
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
tensor_wrapper_name, full_reserved);
SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
LegalizeVarName(tensor_wrapper_name), full_reserved);
}
}
}
......@@ -1176,7 +1189,8 @@ static std::string GenerateGradNodeCreationContent(
std::string compute_require_grad_args = "trace_backward";
for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name();
const std::string& input_autograd_name = "p_autograd_" + input_name;
const std::string& input_autograd_name =
"p_autograd_" + LegalizeVarName(input_name);
if (!input.duplicable()) {
compute_require_grad_args += ", " + input_autograd_name;
......@@ -1184,8 +1198,9 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_OUT_META_TEMPLATE =
" grad_node->SetGradOutMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_OUT_META_TEMPLATE, input_name, input_position);
grad_node_creation_str +=
paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE,
LegalizeVarName(input_name), input_position);
} else {
compute_require_grad_args += ", &" + input_autograd_name;
......@@ -1193,8 +1208,9 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_OUT_META_TEMPLATE =
" grad_node->SetGradOutMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_OUT_META_TEMPLATE, input_name, input_position);
grad_node_creation_str +=
paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE,
LegalizeVarName(input_name), input_position);
}
}
......@@ -1208,7 +1224,7 @@ static std::string GenerateGradNodeCreationContent(
if (!inplace_map.empty() && inplace_map.count(output_name)) {
auto inplace_input_name = inplace_map[output_name];
const std::string& inplace_input_autograd_name =
"p_autograd_" + inplace_input_name;
"p_autograd_" + LegalizeVarName(inplace_input_name);
size_t output_position = fwd_outputs_name_pos_map.at(output_name);
// Intermediate Tensor does not require SetHistory, nor RetainGrad
......@@ -1228,18 +1244,20 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, inplace_input_name, output_position);
SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(inplace_input_name),
output_position);
// Intermediate Tensor does not require CheckAndRetainGrad
if (!output.intermediate()) {
VLOG(6) << "Generated Call RetainGradForTensor";
const char* RETAIN_GRAD_TEMPLATE =
" egr::EagerUtils::CheckAndRetainGrad(%s);\n";
grad_node_creation_str +=
paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, inplace_input_name);
grad_node_creation_str += paddle::string::Sprintf(
RETAIN_GRAD_TEMPLATE, LegalizeVarName(inplace_input_name));
}
} else {
const std::string& output_autograd_name = "p_autograd_" + output_name;
const std::string& output_autograd_name =
"p_autograd_" + LegalizeVarName(output_name);
size_t output_position = fwd_outputs_name_pos_map.at(output_name);
// Intermediate Tensor does not require SetHistory, nor RetainGrad
......@@ -1261,7 +1279,8 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, output_name, output_position);
SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(output_name),
output_position);
} else {
pass_stop_gradient_args += ", " + output_autograd_name;
......@@ -1280,7 +1299,8 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, output_name, output_position);
SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(output_name),
output_position);
}
// Intermediate Tensor does not require CheckAndRetainGrad
......@@ -1288,8 +1308,8 @@ static std::string GenerateGradNodeCreationContent(
VLOG(6) << "Generated Call RetainGradForTensor";
const char* RETAIN_GRAD_TEMPLATE =
" egr::EagerUtils::CheckAndRetainGrad(%s);\n";
grad_node_creation_str +=
paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, output_name);
grad_node_creation_str += paddle::string::Sprintf(
RETAIN_GRAD_TEMPLATE, LegalizeVarName(output_name));
}
}
}
......@@ -1412,9 +1432,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (input.duplicable()) {
const char* FWD_INS_ARG_TEMPLATE =
"const std::vector<paddle::experimental::Tensor>& %s";
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
amp_function_call_args_str_list[input_position] = " NEW_" + input_name;
input_args_str_list[input_position] = paddle::string::Sprintf(
FWD_INS_ARG_TEMPLATE, LegalizeVarName(input_name));
amp_function_call_args_str_list[input_position] =
" NEW_" + LegalizeVarName(input_name);
core_ops_args_type_info[op_type][input_position] = "list";
} else {
......@@ -1433,9 +1454,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (!flag_find_input_name) {
FWD_INS_ARG_TEMPLATE = "const paddle::experimental::Tensor& %s";
}
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
amp_function_call_args_str_list[input_position] = " NEW_" + input_name;
input_args_str_list[input_position] = paddle::string::Sprintf(
FWD_INS_ARG_TEMPLATE, LegalizeVarName(input_name));
amp_function_call_args_str_list[input_position] =
" NEW_" + LegalizeVarName(input_name);
core_ops_args_type_info[op_type][input_position] = "tensor";
}
......@@ -1445,8 +1467,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const char* FWD_INS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE,
input_name, input_name);
ins_contents_str += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, LegalizeVarName(input_name));
if (input.duplicable()) {
const char* AMP_TENSORS_VECTOR_TEMPLATE = "%s,";
amp_tensors_vector_str +=
......@@ -1455,16 +1477,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" auto NEW_%s = egr::AmpAutoCasts(\"%s\", %s, amp_dst_dtype, "
"\"%s\");\n";
amp_auto_cast_str += paddle::string::Sprintf(
AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type);
AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type);
} else {
const char* AMP_TENSORS_VECTOR_TEMPLATE = "{%s},";
amp_tensors_vector_str +=
paddle::string::Sprintf(AMP_TENSORS_VECTOR_TEMPLATE, input_name);
amp_tensors_vector_str += paddle::string::Sprintf(
AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name));
const char* AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = egr::AmpAutoCast(\"%s\", %s, amp_dst_dtype, "
"\"%s\");\n";
amp_auto_cast_str += paddle::string::Sprintf(
AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type);
AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type);
}
}
if (ins_contents_str.size() > 0)
......@@ -1500,35 +1524,41 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" if(%s.size() > 0) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
dispensable_ins_contents_str += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
FWD_INS_CONTENT_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name));
const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE =
" if(%s.size() > 0) "
"amp_tensors_vector.push_back(%s);\n";
dispensable_amp_tensors_vector_str += paddle::string::Sprintf(
FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name);
FWD_AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name),
LegalizeVarName(input_name));
const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(\"%s\", "
"%s, amp_dst_dtype, \"%s\") : %s);\n";
dispensable_amp_auto_cast_str += paddle::string::Sprintf(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name,
input_name, input_name, op_type, input_name);
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name),
LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type, LegalizeVarName(input_name));
} else {
const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.initialized()) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
dispensable_ins_contents_str += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
FWD_INS_CONTENT_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name));
const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE =
" if(%s.initialized()) "
"amp_tensors_vector.push_back({ %s });\n";
dispensable_amp_tensors_vector_str += paddle::string::Sprintf(
FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name);
FWD_AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name),
LegalizeVarName(input_name));
const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(\"%s\", "
"%s, amp_dst_dtype, \"%s\") : %s);\n";
dispensable_amp_auto_cast_str += paddle::string::Sprintf(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name,
input_name, input_name, op_type, input_name);
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name),
LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type, LegalizeVarName(input_name));
}
}
}
......@@ -1550,18 +1580,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (output.duplicable()) {
const char* FWD_NUM_ARG_TEMPLATE =
", std::vector<paddle::experimental::Tensor*>& %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
std::string arg_str = paddle::string::Sprintf(
FWD_NUM_ARG_TEMPLATE, LegalizeVarName(output_var_name));
dygraph_function_args_str += arg_str;
amp_function_call_args_str += (", " + output_var_name);
amp_function_call_args_str += (", " + LegalizeVarName(output_var_name));
core_ops_args_type_info[op_type].push_back("list");
} else {
const char* FWD_NUM_ARG_TEMPLATE = ", paddle::experimental::Tensor* %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
std::string arg_str = paddle::string::Sprintf(
FWD_NUM_ARG_TEMPLATE, LegalizeVarName(output_var_name));
dygraph_function_args_str += arg_str;
amp_function_call_args_str += (", " + output_var_name);
amp_function_call_args_str += (", " + LegalizeVarName(output_var_name));
core_ops_args_type_info[op_type].push_back("tensor");
}
......@@ -1577,8 +1607,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
} else {
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name);
outs_contents_str +=
paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name,
LegalizeVarName(output_var_name));
}
core_ops_args_info[op_type].push_back(output_name);
......@@ -1773,7 +1804,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std::vector<std::string> return_types(output_size);
for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name();
const std::string output_var_args_name = output_name + "Var";
const std::string output_var_args_name =
LegalizeVariableName(output_name + "Var");
std::string out_tensor_str;
size_t return_position = fwd_outputs_name_pos_map.at(output_name);
std::string output_varname = LegalizeVariableName(output_name);
......@@ -1837,9 +1869,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" %s.bump_inplace_version();\n"
" VLOG(3) << \"Tensor(\" << %s.name() << \") uses Inplace "
"Strategy.\";\n";
out_tensor_str = paddle::string::Sprintf(
FWD_OUT_TENSOR_TEMPLATE, output_name, inplace_input_name,
inplace_input_name, inplace_input_name);
out_tensor_str =
paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, output_name,
LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name));
} else {
const char* FWD_OUT_TENSOR_TEMPLATE =
" paddle::experimental::Tensor %s;\n"
......@@ -1854,7 +1888,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (!inplace_map.empty() && inplace_map.count(output_name)) {
// Replace output directly with input in inplace op.
return_contents[return_position] = inplace_map[output_name];
return_contents[return_position] =
LegalizeVarName(inplace_map[output_name]);
} else {
return_contents[return_position] = output_varname;
}
......
......@@ -36,6 +36,11 @@
// phi
#include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// clang-format off
const char* OUT_INITIALIZER_TEMPLATE =
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
......@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody(
continue;
}
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg =
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
auto input_arg = paddle::string::Sprintf(
ARG_TEMPLATE, in_type, TempName(LegalizeVarName(in_name)));
input_args += input_arg;
input_args += ",";
input_args_num++;
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
in_name, arg_idx++, dispensable);
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), op_type,
in_name, arg_idx++, dispensable);
call_api_str += in_name + ", ";
call_api_str += LegalizeVarName(in_name) + ", ";
}
if (!input_args.empty() && input_args.back() == ',') {
......@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody(
input_args += ",";
}
input_args += out_type;
input_args += out_name;
input_args += LegalizeVarName(out_name);
input_args_num++;
if (output.dispensable()) {
......@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE;
outs_initializer +=
paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += paddle::string::Sprintf(out_template, out_name,
LegalizeVarName(out_name));
outs_initializer += ",";
}
const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE
: CAST_VAR_PTR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
out_name, arg_idx++, dispensable);
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
op_type, out_name, arg_idx++, dispensable);
call_api_str += out_name + ", ";
call_api_str += LegalizeVarName(out_name) + ", ";
} else {
// There are few Operators that have duplicable output, like `Out` in
// split op. We need to specify the number of variables for the
......@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") {
input_args += ",";
}
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str;
input_args_num++;
......
......@@ -35,6 +35,12 @@
// phi
#include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// NOTE(pangyoki): Inplace OP with duplicable input.
// The set includes inplace ops that have duplicable input.
// The first Varbase in input needs to be specified for the inplace strategy
......@@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody(
continue;
}
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg =
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
auto input_arg = paddle::string::Sprintf(
ARG_TEMPLATE, in_type, LegalizeVarName(TempName(in_name)));
input_args += input_arg;
input_args += ",";
input_args_num++;
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name,
arg_idx++, dispensable);
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), in_name,
arg_idx++, dispensable);
if (input.dispensable()) {
const auto in_template = input.duplicable()
? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
: INPUT_INITIALIZER_TEMPLATE_WITH_NULL;
ins_initializer_with_null +=
paddle::string::Sprintf(in_template, in_name, in_name, in_name);
paddle::string::Sprintf(in_template, LegalizeVarName(in_name),
in_name, LegalizeVarName(in_name));
} else {
const auto in_template = input.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE;
ins_initializer += paddle::string::Sprintf(in_template, in_name, in_name);
ins_initializer += paddle::string::Sprintf(in_template, in_name,
LegalizeVarName(in_name));
ins_initializer += ",";
}
}
......@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody(
input_args += ",";
}
input_args += out_type;
input_args += out_name;
input_args += LegalizeVarName(out_name);
input_args_num++;
if (output.dispensable()) {
......@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE;
outs_initializer +=
paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += paddle::string::Sprintf(out_template, out_name,
LegalizeVarName(out_name));
outs_initializer += ",";
}
const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name,
arg_idx++, dispensable);
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE(
inplace_map[out_name], "",
......@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody(
// Leaf Var that doesn't stop gradient can't use inplace strategy.
// Increase inplace_version.
inplace_strategy_str += paddle::string::Sprintf(
INPLACE_STRATEGY_TEMPLATE, inplace_input_name, inplace_input_name,
INPLACE_LEAF_ERROR_MESSAGE, inplace_input_name, inplace_input_name,
inplace_input_name);
outs_initializer +=
paddle::string::Sprintf(out_template, out_name, inplace_input_name);
INPLACE_STRATEGY_TEMPLATE, LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name), INPLACE_LEAF_ERROR_MESSAGE,
LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name));
outs_initializer += paddle::string::Sprintf(
out_template, out_name, LegalizeVarName(inplace_input_name));
outs_initializer += ",";
} else {
// There are few Operators that have duplicable output, like `Out` in
......@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") {
input_args += ",";
}
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str;
input_args_num++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册