未验证 提交 832e58d6 编写于 作者: J Jiabin Yang 提交者: GitHub

Fix stray error (#42509)

* fix @ stray error in dygraph

* fix @ stray error in dygraph
上级 06927016
...@@ -56,6 +56,13 @@ static std::unordered_set<std::string> black_ops_list = {"run_program"}; ...@@ -56,6 +56,13 @@ static std::unordered_set<std::string> black_ops_list = {"run_program"};
static std::string LegalizeVariableName(const std::string& var_name) { static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name; std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '-', '_'); // replace all '-' to '_' std::replace(ret.begin(), ret.end(), '-', '_'); // replace all '-' to '_'
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret; return ret;
} }
...@@ -1024,7 +1031,8 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1024,7 +1031,8 @@ static std::string GenerateGradNodeCreationContent(
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")" // egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
for (const proto::OpProto::Var& output : out_vars) { for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name(); const std::string& output_name = output.name();
const std::string& output_autograd_name = "p_autograd_" + output_name; const std::string& output_autograd_name =
"p_autograd_" + LegalizeVarName(output_name);
// output autograd_meta should be got after running TraceOP. // output autograd_meta should be got after running TraceOP.
if (output.duplicable()) { if (output.duplicable()) {
...@@ -1032,12 +1040,13 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1032,12 +1040,13 @@ static std::string GenerateGradNodeCreationContent(
" std::vector<egr::AutogradMeta*> %s = " " std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::autograd_meta(&%s);\n"; "egr::EagerUtils::autograd_meta(&%s);\n";
get_output_autograd_meta_str += paddle::string::Sprintf( get_output_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name); GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name,
LegalizeVarName(output_name));
} else { } else {
// In inplace op, the case where output is duplicable is not considered. // In inplace op, the case where output is duplicable is not considered.
// Replace output directly with input in inplace op. // Replace output directly with input in inplace op.
if (!inplace_map.empty() && inplace_map.count(output_name)) { if (!inplace_map.empty() && inplace_map.count(output_name)) {
auto inplace_input_name = inplace_map[output_name]; auto inplace_input_name = LegalizeVarName(inplace_map[output_name]);
const std::string& inplace_input_autograd_name = const std::string& inplace_input_autograd_name =
"p_autograd_" + inplace_input_name; "p_autograd_" + inplace_input_name;
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
...@@ -1049,9 +1058,9 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1049,9 +1058,9 @@ static std::string GenerateGradNodeCreationContent(
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = " " egr::AutogradMeta* %s = "
"egr::EagerUtils::autograd_meta(&%s);\n"; "egr::EagerUtils::autograd_meta(&%s);\n";
get_output_autograd_meta_str += get_output_autograd_meta_str += paddle::string::Sprintf(
paddle::string::Sprintf(GET_SINGLE_AUTOGRAD_META_TEMPLATE, GET_SINGLE_AUTOGRAD_META_TEMPLATE, output_autograd_name,
output_autograd_name, output_name); LegalizeVarName(output_name));
} }
} }
} }
...@@ -1061,28 +1070,32 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1061,28 +1070,32 @@ static std::string GenerateGradNodeCreationContent(
// inplace). // inplace).
for (const proto::OpProto::Var& input : in_vars) { for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name(); const std::string& input_name = input.name();
const std::string& input_autograd_name = "p_autograd_" + input_name; const std::string& input_autograd_name =
"p_autograd_" + LegalizeVarName(input_name);
if (input.duplicable()) { if (input.duplicable()) {
const char* GET_MULTI_AUTOGRAD_META_TEMPLATE = const char* GET_MULTI_AUTOGRAD_META_TEMPLATE =
" std::vector<egr::AutogradMeta*> %s = " " std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n"; "egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_input_autograd_meta_str += paddle::string::Sprintf( get_input_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name,
LegalizeVarName(input_name));
} else if (input.dispensable()) { } else if (input.dispensable()) {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = " " egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n"; "egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_input_autograd_meta_str += paddle::string::Sprintf( get_input_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name,
LegalizeVarName(input_name));
} else { } else {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = " " egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n"; "egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_input_autograd_meta_str += paddle::string::Sprintf( get_input_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name,
LegalizeVarName(input_name));
} }
} }
VLOG(6) << "Generated inputs autograd_meta"; VLOG(6) << "Generated inputs autograd_meta";
...@@ -1096,7 +1109,7 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1096,7 +1109,7 @@ static std::string GenerateGradNodeCreationContent(
" egr::EagerUtils::CheckInplace(%s, p_autograd_%s, " " egr::EagerUtils::CheckInplace(%s, p_autograd_%s, "
"require_any_grad);\n"; "require_any_grad);\n";
for (auto& inplace_pair : inplace_map) { for (auto& inplace_pair : inplace_map) {
std::string inplace_name = inplace_pair.second; std::string inplace_name = LegalizeVarName(inplace_pair.second);
check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE, check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE,
inplace_name, inplace_name); inplace_name, inplace_name);
} }
...@@ -1159,12 +1172,12 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1159,12 +1172,12 @@ static std::string GenerateGradNodeCreationContent(
if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) { if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
auto inplace_input_name = inplace_map[tensor_wrapper_name]; auto inplace_input_name = inplace_map[tensor_wrapper_name];
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
inplace_input_name, full_reserved); LegalizeVarName(inplace_input_name), full_reserved);
} else { } else {
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
tensor_wrapper_name, full_reserved); LegalizeVarName(tensor_wrapper_name), full_reserved);
} }
} }
} }
...@@ -1176,7 +1189,8 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1176,7 +1189,8 @@ static std::string GenerateGradNodeCreationContent(
std::string compute_require_grad_args = "trace_backward"; std::string compute_require_grad_args = "trace_backward";
for (const proto::OpProto::Var& input : in_vars) { for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name(); const std::string& input_name = input.name();
const std::string& input_autograd_name = "p_autograd_" + input_name; const std::string& input_autograd_name =
"p_autograd_" + LegalizeVarName(input_name);
if (!input.duplicable()) { if (!input.duplicable()) {
compute_require_grad_args += ", " + input_autograd_name; compute_require_grad_args += ", " + input_autograd_name;
...@@ -1184,8 +1198,9 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1184,8 +1198,9 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_OUT_META_TEMPLATE = const char* SET_GRAD_OUT_META_TEMPLATE =
" grad_node->SetGradOutMeta(%s, %d);\n"; " grad_node->SetGradOutMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str +=
SET_GRAD_OUT_META_TEMPLATE, input_name, input_position); paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE,
LegalizeVarName(input_name), input_position);
} else { } else {
compute_require_grad_args += ", &" + input_autograd_name; compute_require_grad_args += ", &" + input_autograd_name;
...@@ -1193,8 +1208,9 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1193,8 +1208,9 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_OUT_META_TEMPLATE = const char* SET_GRAD_OUT_META_TEMPLATE =
" grad_node->SetGradOutMeta(%s, %d);\n"; " grad_node->SetGradOutMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str +=
SET_GRAD_OUT_META_TEMPLATE, input_name, input_position); paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE,
LegalizeVarName(input_name), input_position);
} }
} }
...@@ -1208,7 +1224,7 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1208,7 +1224,7 @@ static std::string GenerateGradNodeCreationContent(
if (!inplace_map.empty() && inplace_map.count(output_name)) { if (!inplace_map.empty() && inplace_map.count(output_name)) {
auto inplace_input_name = inplace_map[output_name]; auto inplace_input_name = inplace_map[output_name];
const std::string& inplace_input_autograd_name = const std::string& inplace_input_autograd_name =
"p_autograd_" + inplace_input_name; "p_autograd_" + LegalizeVarName(inplace_input_name);
size_t output_position = fwd_outputs_name_pos_map.at(output_name); size_t output_position = fwd_outputs_name_pos_map.at(output_name);
// Intermediate Tensor does not require SetHistory, nor RetainGrad // Intermediate Tensor does not require SetHistory, nor RetainGrad
...@@ -1228,18 +1244,20 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1228,18 +1244,20 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_IN_META_TEMPLATE = const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n"; " grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, inplace_input_name, output_position); SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(inplace_input_name),
output_position);
// Intermediate Tensor does not require CheckAndRetainGrad // Intermediate Tensor does not require CheckAndRetainGrad
if (!output.intermediate()) { if (!output.intermediate()) {
VLOG(6) << "Generated Call RetainGradForTensor"; VLOG(6) << "Generated Call RetainGradForTensor";
const char* RETAIN_GRAD_TEMPLATE = const char* RETAIN_GRAD_TEMPLATE =
" egr::EagerUtils::CheckAndRetainGrad(%s);\n"; " egr::EagerUtils::CheckAndRetainGrad(%s);\n";
grad_node_creation_str += grad_node_creation_str += paddle::string::Sprintf(
paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, inplace_input_name); RETAIN_GRAD_TEMPLATE, LegalizeVarName(inplace_input_name));
} }
} else { } else {
const std::string& output_autograd_name = "p_autograd_" + output_name; const std::string& output_autograd_name =
"p_autograd_" + LegalizeVarName(output_name);
size_t output_position = fwd_outputs_name_pos_map.at(output_name); size_t output_position = fwd_outputs_name_pos_map.at(output_name);
// Intermediate Tensor does not require SetHistory, nor RetainGrad // Intermediate Tensor does not require SetHistory, nor RetainGrad
...@@ -1261,7 +1279,8 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1261,7 +1279,8 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_IN_META_TEMPLATE = const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n"; " grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, output_name, output_position); SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(output_name),
output_position);
} else { } else {
pass_stop_gradient_args += ", " + output_autograd_name; pass_stop_gradient_args += ", " + output_autograd_name;
...@@ -1280,7 +1299,8 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1280,7 +1299,8 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_GRAD_IN_META_TEMPLATE = const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n"; " grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, output_name, output_position); SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(output_name),
output_position);
} }
// Intermediate Tensor does not require CheckAndRetainGrad // Intermediate Tensor does not require CheckAndRetainGrad
...@@ -1288,8 +1308,8 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1288,8 +1308,8 @@ static std::string GenerateGradNodeCreationContent(
VLOG(6) << "Generated Call RetainGradForTensor"; VLOG(6) << "Generated Call RetainGradForTensor";
const char* RETAIN_GRAD_TEMPLATE = const char* RETAIN_GRAD_TEMPLATE =
" egr::EagerUtils::CheckAndRetainGrad(%s);\n"; " egr::EagerUtils::CheckAndRetainGrad(%s);\n";
grad_node_creation_str += grad_node_creation_str += paddle::string::Sprintf(
paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, output_name); RETAIN_GRAD_TEMPLATE, LegalizeVarName(output_name));
} }
} }
} }
...@@ -1412,9 +1432,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1412,9 +1432,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (input.duplicable()) { if (input.duplicable()) {
const char* FWD_INS_ARG_TEMPLATE = const char* FWD_INS_ARG_TEMPLATE =
"const std::vector<paddle::experimental::Tensor>& %s"; "const std::vector<paddle::experimental::Tensor>& %s";
input_args_str_list[input_position] = input_args_str_list[input_position] = paddle::string::Sprintf(
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); FWD_INS_ARG_TEMPLATE, LegalizeVarName(input_name));
amp_function_call_args_str_list[input_position] = " NEW_" + input_name; amp_function_call_args_str_list[input_position] =
" NEW_" + LegalizeVarName(input_name);
core_ops_args_type_info[op_type][input_position] = "list"; core_ops_args_type_info[op_type][input_position] = "list";
} else { } else {
...@@ -1433,9 +1454,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1433,9 +1454,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (!flag_find_input_name) { if (!flag_find_input_name) {
FWD_INS_ARG_TEMPLATE = "const paddle::experimental::Tensor& %s"; FWD_INS_ARG_TEMPLATE = "const paddle::experimental::Tensor& %s";
} }
input_args_str_list[input_position] = input_args_str_list[input_position] = paddle::string::Sprintf(
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); FWD_INS_ARG_TEMPLATE, LegalizeVarName(input_name));
amp_function_call_args_str_list[input_position] = " NEW_" + input_name; amp_function_call_args_str_list[input_position] =
" NEW_" + LegalizeVarName(input_name);
core_ops_args_type_info[op_type][input_position] = "tensor"; core_ops_args_type_info[op_type][input_position] = "tensor";
} }
...@@ -1445,8 +1467,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1445,8 +1467,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const char* FWD_INS_CONTENT_TEMPLATE = const char* FWD_INS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },"; "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE, ins_contents_str += paddle::string::Sprintf(
input_name, input_name); FWD_INS_CONTENT_TEMPLATE, input_name, LegalizeVarName(input_name));
if (input.duplicable()) { if (input.duplicable()) {
const char* AMP_TENSORS_VECTOR_TEMPLATE = "%s,"; const char* AMP_TENSORS_VECTOR_TEMPLATE = "%s,";
amp_tensors_vector_str += amp_tensors_vector_str +=
...@@ -1455,16 +1477,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1455,16 +1477,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" auto NEW_%s = egr::AmpAutoCasts(\"%s\", %s, amp_dst_dtype, " " auto NEW_%s = egr::AmpAutoCasts(\"%s\", %s, amp_dst_dtype, "
"\"%s\");\n"; "\"%s\");\n";
amp_auto_cast_str += paddle::string::Sprintf( amp_auto_cast_str += paddle::string::Sprintf(
AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type); AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type);
} else { } else {
const char* AMP_TENSORS_VECTOR_TEMPLATE = "{%s},"; const char* AMP_TENSORS_VECTOR_TEMPLATE = "{%s},";
amp_tensors_vector_str += amp_tensors_vector_str += paddle::string::Sprintf(
paddle::string::Sprintf(AMP_TENSORS_VECTOR_TEMPLATE, input_name); AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name));
const char* AMP_AUTO_CAST_TEMPLATE = const char* AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = egr::AmpAutoCast(\"%s\", %s, amp_dst_dtype, " " auto NEW_%s = egr::AmpAutoCast(\"%s\", %s, amp_dst_dtype, "
"\"%s\");\n"; "\"%s\");\n";
amp_auto_cast_str += paddle::string::Sprintf( amp_auto_cast_str += paddle::string::Sprintf(
AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type); AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type);
} }
} }
if (ins_contents_str.size() > 0) if (ins_contents_str.size() > 0)
...@@ -1500,35 +1524,41 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1500,35 +1524,41 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" if(%s.size() > 0) " " if(%s.size() > 0) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n"; "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
dispensable_ins_contents_str += paddle::string::Sprintf( dispensable_ins_contents_str += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); FWD_INS_CONTENT_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name));
const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE = const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE =
" if(%s.size() > 0) " " if(%s.size() > 0) "
"amp_tensors_vector.push_back(%s);\n"; "amp_tensors_vector.push_back(%s);\n";
dispensable_amp_tensors_vector_str += paddle::string::Sprintf( dispensable_amp_tensors_vector_str += paddle::string::Sprintf(
FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name); FWD_AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name),
LegalizeVarName(input_name));
const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE = const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(\"%s\", " " auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(\"%s\", "
"%s, amp_dst_dtype, \"%s\") : %s);\n"; "%s, amp_dst_dtype, \"%s\") : %s);\n";
dispensable_amp_auto_cast_str += paddle::string::Sprintf( dispensable_amp_auto_cast_str += paddle::string::Sprintf(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name, DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name),
input_name, input_name, op_type, input_name); LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type, LegalizeVarName(input_name));
} else { } else {
const char* FWD_INS_CONTENT_TEMPLATE = const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.initialized()) " " if(%s.initialized()) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n"; "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
dispensable_ins_contents_str += paddle::string::Sprintf( dispensable_ins_contents_str += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); FWD_INS_CONTENT_TEMPLATE, LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name));
const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE = const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE =
" if(%s.initialized()) " " if(%s.initialized()) "
"amp_tensors_vector.push_back({ %s });\n"; "amp_tensors_vector.push_back({ %s });\n";
dispensable_amp_tensors_vector_str += paddle::string::Sprintf( dispensable_amp_tensors_vector_str += paddle::string::Sprintf(
FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name); FWD_AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name),
LegalizeVarName(input_name));
const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE = const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(\"%s\", " " auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(\"%s\", "
"%s, amp_dst_dtype, \"%s\") : %s);\n"; "%s, amp_dst_dtype, \"%s\") : %s);\n";
dispensable_amp_auto_cast_str += paddle::string::Sprintf( dispensable_amp_auto_cast_str += paddle::string::Sprintf(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name, DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name),
input_name, input_name, op_type, input_name); LegalizeVarName(input_name), input_name,
LegalizeVarName(input_name), op_type, LegalizeVarName(input_name));
} }
} }
} }
...@@ -1550,18 +1580,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1550,18 +1580,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (output.duplicable()) { if (output.duplicable()) {
const char* FWD_NUM_ARG_TEMPLATE = const char* FWD_NUM_ARG_TEMPLATE =
", std::vector<paddle::experimental::Tensor*>& %s"; ", std::vector<paddle::experimental::Tensor*>& %s";
std::string arg_str = std::string arg_str = paddle::string::Sprintf(
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); FWD_NUM_ARG_TEMPLATE, LegalizeVarName(output_var_name));
dygraph_function_args_str += arg_str; dygraph_function_args_str += arg_str;
amp_function_call_args_str += (", " + output_var_name); amp_function_call_args_str += (", " + LegalizeVarName(output_var_name));
core_ops_args_type_info[op_type].push_back("list"); core_ops_args_type_info[op_type].push_back("list");
} else { } else {
const char* FWD_NUM_ARG_TEMPLATE = ", paddle::experimental::Tensor* %s"; const char* FWD_NUM_ARG_TEMPLATE = ", paddle::experimental::Tensor* %s";
std::string arg_str = std::string arg_str = paddle::string::Sprintf(
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); FWD_NUM_ARG_TEMPLATE, LegalizeVarName(output_var_name));
dygraph_function_args_str += arg_str; dygraph_function_args_str += arg_str;
amp_function_call_args_str += (", " + output_var_name); amp_function_call_args_str += (", " + LegalizeVarName(output_var_name));
core_ops_args_type_info[op_type].push_back("tensor"); core_ops_args_type_info[op_type].push_back("tensor");
} }
...@@ -1577,8 +1607,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1577,8 +1607,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
} else { } else {
const char* FWD_OUTS_CONTENT_TEMPLATE = const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },"; "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
outs_contents_str += paddle::string::Sprintf( outs_contents_str +=
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name); paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name,
LegalizeVarName(output_var_name));
} }
core_ops_args_info[op_type].push_back(output_name); core_ops_args_info[op_type].push_back(output_name);
...@@ -1773,7 +1804,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1773,7 +1804,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std::vector<std::string> return_types(output_size); std::vector<std::string> return_types(output_size);
for (const proto::OpProto::Var& output : out_vars) { for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name(); const std::string& output_name = output.name();
const std::string output_var_args_name = output_name + "Var"; const std::string output_var_args_name =
LegalizeVariableName(output_name + "Var");
std::string out_tensor_str; std::string out_tensor_str;
size_t return_position = fwd_outputs_name_pos_map.at(output_name); size_t return_position = fwd_outputs_name_pos_map.at(output_name);
std::string output_varname = LegalizeVariableName(output_name); std::string output_varname = LegalizeVariableName(output_name);
...@@ -1837,9 +1869,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1837,9 +1869,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" %s.bump_inplace_version();\n" " %s.bump_inplace_version();\n"
" VLOG(3) << \"Tensor(\" << %s.name() << \") uses Inplace " " VLOG(3) << \"Tensor(\" << %s.name() << \") uses Inplace "
"Strategy.\";\n"; "Strategy.\";\n";
out_tensor_str = paddle::string::Sprintf( out_tensor_str =
FWD_OUT_TENSOR_TEMPLATE, output_name, inplace_input_name, paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, output_name,
inplace_input_name, inplace_input_name); LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name));
} else { } else {
const char* FWD_OUT_TENSOR_TEMPLATE = const char* FWD_OUT_TENSOR_TEMPLATE =
" paddle::experimental::Tensor %s;\n" " paddle::experimental::Tensor %s;\n"
...@@ -1854,7 +1888,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1854,7 +1888,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (!inplace_map.empty() && inplace_map.count(output_name)) { if (!inplace_map.empty() && inplace_map.count(output_name)) {
// Replace output directly with input in inplace op. // Replace output directly with input in inplace op.
return_contents[return_position] = inplace_map[output_name]; return_contents[return_position] =
LegalizeVarName(inplace_map[output_name]);
} else { } else {
return_contents[return_position] = output_varname; return_contents[return_position] = output_varname;
} }
......
...@@ -36,6 +36,11 @@ ...@@ -36,6 +36,11 @@
// phi // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// clang-format off // clang-format off
const char* OUT_INITIALIZER_TEMPLATE = const char* OUT_INITIALIZER_TEMPLATE =
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})"; R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
...@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody( ...@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody(
continue; continue;
} }
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg = auto input_arg = paddle::string::Sprintf(
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); ARG_TEMPLATE, in_type, TempName(LegalizeVarName(in_name)));
input_args += input_arg; input_args += input_arg;
input_args += ","; input_args += ",";
input_args_num++; input_args_num++;
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false"; auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type, ins_cast_str +=
in_name, arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), op_type,
in_name, arg_idx++, dispensable);
call_api_str += in_name + ", "; call_api_str += LegalizeVarName(in_name) + ", ";
} }
if (!input_args.empty() && input_args.back() == ',') { if (!input_args.empty() && input_args.back() == ',') {
...@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody( ...@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody(
input_args += ","; input_args += ",";
} }
input_args += out_type; input_args += out_type;
input_args += out_name; input_args += LegalizeVarName(out_name);
input_args_num++; input_args_num++;
if (output.dispensable()) { if (output.dispensable()) {
...@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody( ...@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable() const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE ? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE; : INPUT_INITIALIZER_TEMPLATE;
outs_initializer += outs_initializer += paddle::string::Sprintf(out_template, out_name,
paddle::string::Sprintf(out_template, out_name, out_name); LegalizeVarName(out_name));
outs_initializer += ","; outs_initializer += ",";
} }
const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE
: CAST_VAR_PTR_TEMPLATE; : CAST_VAR_PTR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false"; auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type, ins_cast_str +=
out_name, arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
op_type, out_name, arg_idx++, dispensable);
call_api_str += out_name + ", "; call_api_str += LegalizeVarName(out_name) + ", ";
} else { } else {
// There are few Operators that have duplicable output, like `Out` in // There are few Operators that have duplicable output, like `Out` in
// split op. We need to specify the number of variables for the // split op. We need to specify the number of variables for the
...@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody( ...@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") { if (input_args != "") {
input_args += ","; input_args += ",";
} }
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE; input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str; input_args += out_num_str;
input_args_num++; input_args_num++;
......
...@@ -35,6 +35,12 @@ ...@@ -35,6 +35,12 @@
// phi // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// NOTE(pangyoki): Inplace OP with duplicable input. // NOTE(pangyoki): Inplace OP with duplicable input.
// The set includes inplace ops that have duplicable input. // The set includes inplace ops that have duplicable input.
// The first Varbase in input needs to be specified for the inplace strategy // The first Varbase in input needs to be specified for the inplace strategy
...@@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody( ...@@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody(
continue; continue;
} }
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg = auto input_arg = paddle::string::Sprintf(
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); ARG_TEMPLATE, in_type, LegalizeVarName(TempName(in_name)));
input_args += input_arg; input_args += input_arg;
input_args += ","; input_args += ",";
input_args_num++; input_args_num++;
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false"; auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name, ins_cast_str +=
arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), in_name,
arg_idx++, dispensable);
if (input.dispensable()) { if (input.dispensable()) {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
: INPUT_INITIALIZER_TEMPLATE_WITH_NULL; : INPUT_INITIALIZER_TEMPLATE_WITH_NULL;
ins_initializer_with_null += ins_initializer_with_null +=
paddle::string::Sprintf(in_template, in_name, in_name, in_name); paddle::string::Sprintf(in_template, LegalizeVarName(in_name),
in_name, LegalizeVarName(in_name));
} else { } else {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE ? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE; : INPUT_INITIALIZER_TEMPLATE;
ins_initializer += paddle::string::Sprintf(in_template, in_name, in_name); ins_initializer += paddle::string::Sprintf(in_template, in_name,
LegalizeVarName(in_name));
ins_initializer += ","; ins_initializer += ",";
} }
} }
...@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody( ...@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody(
input_args += ","; input_args += ",";
} }
input_args += out_type; input_args += out_type;
input_args += out_name; input_args += LegalizeVarName(out_name);
input_args_num++; input_args_num++;
if (output.dispensable()) { if (output.dispensable()) {
...@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody( ...@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable() const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE ? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE; : INPUT_INITIALIZER_TEMPLATE;
outs_initializer += outs_initializer += paddle::string::Sprintf(out_template, out_name,
paddle::string::Sprintf(out_template, out_name, out_name); LegalizeVarName(out_name));
outs_initializer += ","; outs_initializer += ",";
} }
const auto in_cast_type = const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false"; auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name, ins_cast_str +=
arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) { } else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
inplace_map[out_name], "", inplace_map[out_name], "",
...@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody( ...@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody(
// Leaf Var that doesn't stop gradient can't use inplace strategy. // Leaf Var that doesn't stop gradient can't use inplace strategy.
// Increase inplace_version. // Increase inplace_version.
inplace_strategy_str += paddle::string::Sprintf( inplace_strategy_str += paddle::string::Sprintf(
INPLACE_STRATEGY_TEMPLATE, inplace_input_name, inplace_input_name, INPLACE_STRATEGY_TEMPLATE, LegalizeVarName(inplace_input_name),
INPLACE_LEAF_ERROR_MESSAGE, inplace_input_name, inplace_input_name, LegalizeVarName(inplace_input_name), INPLACE_LEAF_ERROR_MESSAGE,
inplace_input_name); LegalizeVarName(inplace_input_name),
outs_initializer += LegalizeVarName(inplace_input_name),
paddle::string::Sprintf(out_template, out_name, inplace_input_name); LegalizeVarName(inplace_input_name));
outs_initializer += paddle::string::Sprintf(
out_template, out_name, LegalizeVarName(inplace_input_name));
outs_initializer += ","; outs_initializer += ",";
} else { } else {
// There are few Operators that have duplicable output, like `Out` in // There are few Operators that have duplicable output, like `Out` in
...@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody( ...@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") { if (input_args != "") {
input_args += ","; input_args += ",";
} }
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE; input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str; input_args += out_num_str;
input_args_num++; input_args_num++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册