diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 6b962b537edf97456737b290a21767f93b364ff7..44fa8461f2fe94e7af4a3e48d55aa119d9fde4e6 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -56,6 +56,13 @@ static std::unordered_set black_ops_list = {"run_program"}; static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; std::replace(ret.begin(), ret.end(), '-', '_'); // replace all '-' to '_' + std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_' + return ret; +} + +static std::string LegalizeVarName(const std::string& var_name) { + std::string ret = var_name; + std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_' return ret; } @@ -1024,7 +1031,8 @@ static std::string GenerateGradNodeCreationContent( // egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")" for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); - const std::string& output_autograd_name = "p_autograd_" + output_name; + const std::string& output_autograd_name = + "p_autograd_" + LegalizeVarName(output_name); // output autograd_meta should be got after running TraceOP. if (output.duplicable()) { @@ -1032,12 +1040,13 @@ static std::string GenerateGradNodeCreationContent( " std::vector %s = " "egr::EagerUtils::autograd_meta(&%s);\n"; get_output_autograd_meta_str += paddle::string::Sprintf( - GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name); + GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, + LegalizeVarName(output_name)); } else { // In inplace op, the case where output is duplicable is not considered. // Replace output directly with input in inplace op. if (!inplace_map.empty() && inplace_map.count(output_name)) { - auto inplace_input_name = inplace_map[output_name]; + auto inplace_input_name = LegalizeVarName(inplace_map[output_name]); const std::string& inplace_input_autograd_name = "p_autograd_" + inplace_input_name; const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = @@ -1049,9 +1058,9 @@ static std::string GenerateGradNodeCreationContent( const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = " egr::AutogradMeta* %s = " "egr::EagerUtils::autograd_meta(&%s);\n"; - get_output_autograd_meta_str += - paddle::string::Sprintf(GET_SINGLE_AUTOGRAD_META_TEMPLATE, - output_autograd_name, output_name); + get_output_autograd_meta_str += paddle::string::Sprintf( + GET_SINGLE_AUTOGRAD_META_TEMPLATE, output_autograd_name, + LegalizeVarName(output_name)); } } } @@ -1061,28 +1070,32 @@ static std::string GenerateGradNodeCreationContent( // inplace). for (const proto::OpProto::Var& input : in_vars) { const std::string& input_name = input.name(); - const std::string& input_autograd_name = "p_autograd_" + input_name; + const std::string& input_autograd_name = + "p_autograd_" + LegalizeVarName(input_name); if (input.duplicable()) { const char* GET_MULTI_AUTOGRAD_META_TEMPLATE = " std::vector %s = " "egr::EagerUtils::nullable_autograd_meta(%s);\n"; get_input_autograd_meta_str += paddle::string::Sprintf( - GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); + GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, + LegalizeVarName(input_name)); } else if (input.dispensable()) { const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = " egr::AutogradMeta* %s = " "egr::EagerUtils::nullable_autograd_meta(%s);\n"; get_input_autograd_meta_str += paddle::string::Sprintf( - GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); + GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, + LegalizeVarName(input_name)); } else { const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = " egr::AutogradMeta* %s = " "egr::EagerUtils::nullable_autograd_meta(%s);\n"; get_input_autograd_meta_str += paddle::string::Sprintf( - GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); + GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, + LegalizeVarName(input_name)); } } VLOG(6) << "Generated inputs autograd_meta"; @@ -1096,7 +1109,7 @@ static std::string GenerateGradNodeCreationContent( " egr::EagerUtils::CheckInplace(%s, p_autograd_%s, " "require_any_grad);\n"; for (auto& inplace_pair : inplace_map) { - std::string inplace_name = inplace_pair.second; + std::string inplace_name = LegalizeVarName(inplace_pair.second); check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE, inplace_name, inplace_name); } @@ -1159,12 +1172,12 @@ static std::string GenerateGradNodeCreationContent( if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) { auto inplace_input_name = inplace_map[tensor_wrapper_name]; grad_node_creation_str += paddle::string::Sprintf( - SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, - inplace_input_name, full_reserved); + SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name), + LegalizeVarName(inplace_input_name), full_reserved); } else { grad_node_creation_str += paddle::string::Sprintf( - SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, - tensor_wrapper_name, full_reserved); + SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name), + LegalizeVarName(tensor_wrapper_name), full_reserved); } } } @@ -1176,7 +1189,8 @@ static std::string GenerateGradNodeCreationContent( std::string compute_require_grad_args = "trace_backward"; for (const proto::OpProto::Var& input : in_vars) { const std::string& input_name = input.name(); - const std::string& input_autograd_name = "p_autograd_" + input_name; + const std::string& input_autograd_name = + "p_autograd_" + LegalizeVarName(input_name); if (!input.duplicable()) { compute_require_grad_args += ", " + input_autograd_name; @@ -1184,8 +1198,9 @@ static std::string GenerateGradNodeCreationContent( const char* SET_GRAD_OUT_META_TEMPLATE = " grad_node->SetGradOutMeta(%s, %d);\n"; - grad_node_creation_str += paddle::string::Sprintf( - SET_GRAD_OUT_META_TEMPLATE, input_name, input_position); + grad_node_creation_str += + paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE, + LegalizeVarName(input_name), input_position); } else { compute_require_grad_args += ", &" + input_autograd_name; @@ -1193,8 +1208,9 @@ static std::string GenerateGradNodeCreationContent( const char* SET_GRAD_OUT_META_TEMPLATE = " grad_node->SetGradOutMeta(%s, %d);\n"; - grad_node_creation_str += paddle::string::Sprintf( - SET_GRAD_OUT_META_TEMPLATE, input_name, input_position); + grad_node_creation_str += + paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE, + LegalizeVarName(input_name), input_position); } } @@ -1208,7 +1224,7 @@ static std::string GenerateGradNodeCreationContent( if (!inplace_map.empty() && inplace_map.count(output_name)) { auto inplace_input_name = inplace_map[output_name]; const std::string& inplace_input_autograd_name = - "p_autograd_" + inplace_input_name; + "p_autograd_" + LegalizeVarName(inplace_input_name); size_t output_position = fwd_outputs_name_pos_map.at(output_name); // Intermediate Tensor does not require SetHistory, nor RetainGrad @@ -1228,18 +1244,20 @@ static std::string GenerateGradNodeCreationContent( const char* SET_GRAD_IN_META_TEMPLATE = " grad_node->SetGradInMeta(%s, %d);\n"; grad_node_creation_str += paddle::string::Sprintf( - SET_GRAD_IN_META_TEMPLATE, inplace_input_name, output_position); + SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(inplace_input_name), + output_position); // Intermediate Tensor does not require CheckAndRetainGrad if (!output.intermediate()) { VLOG(6) << "Generated Call RetainGradForTensor"; const char* RETAIN_GRAD_TEMPLATE = " egr::EagerUtils::CheckAndRetainGrad(%s);\n"; - grad_node_creation_str += - paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, inplace_input_name); + grad_node_creation_str += paddle::string::Sprintf( + RETAIN_GRAD_TEMPLATE, LegalizeVarName(inplace_input_name)); } } else { - const std::string& output_autograd_name = "p_autograd_" + output_name; + const std::string& output_autograd_name = + "p_autograd_" + LegalizeVarName(output_name); size_t output_position = fwd_outputs_name_pos_map.at(output_name); // Intermediate Tensor does not require SetHistory, nor RetainGrad @@ -1261,7 +1279,8 @@ static std::string GenerateGradNodeCreationContent( const char* SET_GRAD_IN_META_TEMPLATE = " grad_node->SetGradInMeta(%s, %d);\n"; grad_node_creation_str += paddle::string::Sprintf( - SET_GRAD_IN_META_TEMPLATE, output_name, output_position); + SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(output_name), + output_position); } else { pass_stop_gradient_args += ", " + output_autograd_name; @@ -1280,7 +1299,8 @@ static std::string GenerateGradNodeCreationContent( const char* SET_GRAD_IN_META_TEMPLATE = " grad_node->SetGradInMeta(%s, %d);\n"; grad_node_creation_str += paddle::string::Sprintf( - SET_GRAD_IN_META_TEMPLATE, output_name, output_position); + SET_GRAD_IN_META_TEMPLATE, LegalizeVarName(output_name), + output_position); } // Intermediate Tensor does not require CheckAndRetainGrad @@ -1288,8 +1308,8 @@ static std::string GenerateGradNodeCreationContent( VLOG(6) << "Generated Call RetainGradForTensor"; const char* RETAIN_GRAD_TEMPLATE = " egr::EagerUtils::CheckAndRetainGrad(%s);\n"; - grad_node_creation_str += - paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, output_name); + grad_node_creation_str += paddle::string::Sprintf( + RETAIN_GRAD_TEMPLATE, LegalizeVarName(output_name)); } } } @@ -1412,9 +1432,10 @@ static std::pair GenerateForwardFunctionContents( if (input.duplicable()) { const char* FWD_INS_ARG_TEMPLATE = "const std::vector& %s"; - input_args_str_list[input_position] = - paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); - amp_function_call_args_str_list[input_position] = " NEW_" + input_name; + input_args_str_list[input_position] = paddle::string::Sprintf( + FWD_INS_ARG_TEMPLATE, LegalizeVarName(input_name)); + amp_function_call_args_str_list[input_position] = + " NEW_" + LegalizeVarName(input_name); core_ops_args_type_info[op_type][input_position] = "list"; } else { @@ -1433,9 +1454,10 @@ static std::pair GenerateForwardFunctionContents( if (!flag_find_input_name) { FWD_INS_ARG_TEMPLATE = "const paddle::experimental::Tensor& %s"; } - input_args_str_list[input_position] = - paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); - amp_function_call_args_str_list[input_position] = " NEW_" + input_name; + input_args_str_list[input_position] = paddle::string::Sprintf( + FWD_INS_ARG_TEMPLATE, LegalizeVarName(input_name)); + amp_function_call_args_str_list[input_position] = + " NEW_" + LegalizeVarName(input_name); core_ops_args_type_info[op_type][input_position] = "tensor"; } @@ -1445,8 +1467,8 @@ static std::pair GenerateForwardFunctionContents( const char* FWD_INS_CONTENT_TEMPLATE = "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },"; - ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE, - input_name, input_name); + ins_contents_str += paddle::string::Sprintf( + FWD_INS_CONTENT_TEMPLATE, input_name, LegalizeVarName(input_name)); if (input.duplicable()) { const char* AMP_TENSORS_VECTOR_TEMPLATE = "%s,"; amp_tensors_vector_str += @@ -1455,16 +1477,18 @@ static std::pair GenerateForwardFunctionContents( " auto NEW_%s = egr::AmpAutoCasts(\"%s\", %s, amp_dst_dtype, " "\"%s\");\n"; amp_auto_cast_str += paddle::string::Sprintf( - AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type); + AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), input_name, + LegalizeVarName(input_name), op_type); } else { const char* AMP_TENSORS_VECTOR_TEMPLATE = "{%s},"; - amp_tensors_vector_str += - paddle::string::Sprintf(AMP_TENSORS_VECTOR_TEMPLATE, input_name); + amp_tensors_vector_str += paddle::string::Sprintf( + AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name)); const char* AMP_AUTO_CAST_TEMPLATE = " auto NEW_%s = egr::AmpAutoCast(\"%s\", %s, amp_dst_dtype, " "\"%s\");\n"; amp_auto_cast_str += paddle::string::Sprintf( - AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type); + AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), input_name, + LegalizeVarName(input_name), op_type); } } if (ins_contents_str.size() > 0) @@ -1500,35 +1524,41 @@ static std::pair GenerateForwardFunctionContents( " if(%s.size() > 0) " "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n"; dispensable_ins_contents_str += paddle::string::Sprintf( - FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); + FWD_INS_CONTENT_TEMPLATE, LegalizeVarName(input_name), input_name, + LegalizeVarName(input_name)); const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE = " if(%s.size() > 0) " "amp_tensors_vector.push_back(%s);\n"; dispensable_amp_tensors_vector_str += paddle::string::Sprintf( - FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name); + FWD_AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name), + LegalizeVarName(input_name)); const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE = " auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(\"%s\", " "%s, amp_dst_dtype, \"%s\") : %s);\n"; dispensable_amp_auto_cast_str += paddle::string::Sprintf( - DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name, - input_name, input_name, op_type, input_name); + DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), + LegalizeVarName(input_name), input_name, + LegalizeVarName(input_name), op_type, LegalizeVarName(input_name)); } else { const char* FWD_INS_CONTENT_TEMPLATE = " if(%s.initialized()) " "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n"; dispensable_ins_contents_str += paddle::string::Sprintf( - FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); + FWD_INS_CONTENT_TEMPLATE, LegalizeVarName(input_name), input_name, + LegalizeVarName(input_name)); const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE = " if(%s.initialized()) " "amp_tensors_vector.push_back({ %s });\n"; dispensable_amp_tensors_vector_str += paddle::string::Sprintf( - FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name); + FWD_AMP_TENSORS_VECTOR_TEMPLATE, LegalizeVarName(input_name), + LegalizeVarName(input_name)); const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE = " auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(\"%s\", " "%s, amp_dst_dtype, \"%s\") : %s);\n"; dispensable_amp_auto_cast_str += paddle::string::Sprintf( - DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name, - input_name, input_name, op_type, input_name); + DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, LegalizeVarName(input_name), + LegalizeVarName(input_name), input_name, + LegalizeVarName(input_name), op_type, LegalizeVarName(input_name)); } } } @@ -1550,18 +1580,18 @@ static std::pair GenerateForwardFunctionContents( if (output.duplicable()) { const char* FWD_NUM_ARG_TEMPLATE = ", std::vector& %s"; - std::string arg_str = - paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); + std::string arg_str = paddle::string::Sprintf( + FWD_NUM_ARG_TEMPLATE, LegalizeVarName(output_var_name)); dygraph_function_args_str += arg_str; - amp_function_call_args_str += (", " + output_var_name); + amp_function_call_args_str += (", " + LegalizeVarName(output_var_name)); core_ops_args_type_info[op_type].push_back("list"); } else { const char* FWD_NUM_ARG_TEMPLATE = ", paddle::experimental::Tensor* %s"; - std::string arg_str = - paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); + std::string arg_str = paddle::string::Sprintf( + FWD_NUM_ARG_TEMPLATE, LegalizeVarName(output_var_name)); dygraph_function_args_str += arg_str; - amp_function_call_args_str += (", " + output_var_name); + amp_function_call_args_str += (", " + LegalizeVarName(output_var_name)); core_ops_args_type_info[op_type].push_back("tensor"); } @@ -1577,8 +1607,9 @@ static std::pair GenerateForwardFunctionContents( } else { const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },"; - outs_contents_str += paddle::string::Sprintf( - FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name); + outs_contents_str += + paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name, + LegalizeVarName(output_var_name)); } core_ops_args_info[op_type].push_back(output_name); @@ -1773,7 +1804,8 @@ static std::pair GenerateForwardFunctionContents( std::vector return_types(output_size); for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); - const std::string output_var_args_name = output_name + "Var"; + const std::string output_var_args_name = + LegalizeVariableName(output_name + "Var"); std::string out_tensor_str; size_t return_position = fwd_outputs_name_pos_map.at(output_name); std::string output_varname = LegalizeVariableName(output_name); @@ -1837,9 +1869,11 @@ static std::pair GenerateForwardFunctionContents( " %s.bump_inplace_version();\n" " VLOG(3) << \"Tensor(\" << %s.name() << \") uses Inplace " "Strategy.\";\n"; - out_tensor_str = paddle::string::Sprintf( - FWD_OUT_TENSOR_TEMPLATE, output_name, inplace_input_name, - inplace_input_name, inplace_input_name); + out_tensor_str = + paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, output_name, + LegalizeVarName(inplace_input_name), + LegalizeVarName(inplace_input_name), + LegalizeVarName(inplace_input_name)); } else { const char* FWD_OUT_TENSOR_TEMPLATE = " paddle::experimental::Tensor %s;\n" @@ -1854,7 +1888,8 @@ static std::pair GenerateForwardFunctionContents( if (!inplace_map.empty() && inplace_map.count(output_name)) { // Replace output directly with input in inplace op. - return_contents[return_position] = inplace_map[output_name]; + return_contents[return_position] = + LegalizeVarName(inplace_map[output_name]); } else { return_contents[return_position] = output_varname; } diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc index 2ac12165c1a66c0379442284c6ad68f6c2c32bfe..b546aa2d76bcd2264abde045d3d09d8c04a17762 100644 --- a/paddle/fluid/pybind/eager_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_op_function_generator.cc @@ -36,6 +36,11 @@ // phi #include "paddle/phi/kernels/declarations.h" +static std::string LegalizeVarName(const std::string& var_name) { + std::string ret = var_name; + std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_' + return ret; +} // clang-format off const char* OUT_INITIALIZER_TEMPLATE = R"({"%s", {std::shared_ptr(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})"; @@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody( continue; } const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; - auto input_arg = - paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); + auto input_arg = paddle::string::Sprintf( + ARG_TEMPLATE, in_type, TempName(LegalizeVarName(in_name))); input_args += input_arg; input_args += ","; input_args_num++; const auto in_cast_type = input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; auto dispensable = input.dispensable() ? "true" : "false"; - ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type, - in_name, arg_idx++, dispensable); + ins_cast_str += + paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), op_type, + in_name, arg_idx++, dispensable); - call_api_str += in_name + ", "; + call_api_str += LegalizeVarName(in_name) + ", "; } if (!input_args.empty() && input_args.back() == ',') { @@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody( input_args += ","; } input_args += out_type; - input_args += out_name; + input_args += LegalizeVarName(out_name); input_args_num++; if (output.dispensable()) { @@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody( const auto out_template = output.duplicable() ? INPUT_LIST_INITIALIZER_TEMPLATE : INPUT_INITIALIZER_TEMPLATE; - outs_initializer += - paddle::string::Sprintf(out_template, out_name, out_name); + outs_initializer += paddle::string::Sprintf(out_template, out_name, + LegalizeVarName(out_name)); outs_initializer += ","; } const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE : CAST_VAR_PTR_TEMPLATE; auto dispensable = output.dispensable() ? "true" : "false"; - ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type, - out_name, arg_idx++, dispensable); + ins_cast_str += + paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name), + op_type, out_name, arg_idx++, dispensable); - call_api_str += out_name + ", "; + call_api_str += LegalizeVarName(out_name) + ", "; } else { // There are few Operators that have duplicable output, like `Out` in // split op. We need to specify the number of variables for the @@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody( if (input_args != "") { input_args += ","; } - auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); + auto out_num_str = + paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name)); input_args += ARG_OUT_NUM_TYPE; input_args += out_num_str; input_args_num++; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 6bbaa147ace5574cb4f2b3d619a8c8d5a4965c9f..a9e286a6fa049260d449b5850134db8e3ff23bfd 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -35,6 +35,12 @@ // phi #include "paddle/phi/kernels/declarations.h" +static std::string LegalizeVarName(const std::string& var_name) { + std::string ret = var_name; + std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_' + return ret; +} + // NOTE(pangyoki): Inplace OP with duplicable input. // The set includes inplace ops that have duplicable input. // The first Varbase in input needs to be specified for the inplace strategy @@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody( continue; } const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; - auto input_arg = - paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); + auto input_arg = paddle::string::Sprintf( + ARG_TEMPLATE, in_type, LegalizeVarName(TempName(in_name))); input_args += input_arg; input_args += ","; input_args_num++; const auto in_cast_type = input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; auto dispensable = input.dispensable() ? "true" : "false"; - ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name, - arg_idx++, dispensable); + ins_cast_str += + paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), in_name, + arg_idx++, dispensable); if (input.dispensable()) { const auto in_template = input.duplicable() ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST : INPUT_INITIALIZER_TEMPLATE_WITH_NULL; ins_initializer_with_null += - paddle::string::Sprintf(in_template, in_name, in_name, in_name); + paddle::string::Sprintf(in_template, LegalizeVarName(in_name), + in_name, LegalizeVarName(in_name)); } else { const auto in_template = input.duplicable() ? INPUT_LIST_INITIALIZER_TEMPLATE : INPUT_INITIALIZER_TEMPLATE; - ins_initializer += paddle::string::Sprintf(in_template, in_name, in_name); + ins_initializer += paddle::string::Sprintf(in_template, in_name, + LegalizeVarName(in_name)); ins_initializer += ","; } } @@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody( input_args += ","; } input_args += out_type; - input_args += out_name; + input_args += LegalizeVarName(out_name); input_args_num++; if (output.dispensable()) { @@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody( const auto out_template = output.duplicable() ? INPUT_LIST_INITIALIZER_TEMPLATE : INPUT_INITIALIZER_TEMPLATE; - outs_initializer += - paddle::string::Sprintf(out_template, out_name, out_name); + outs_initializer += paddle::string::Sprintf(out_template, out_name, + LegalizeVarName(out_name)); outs_initializer += ","; } const auto in_cast_type = output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; auto dispensable = output.dispensable() ? "true" : "false"; - ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name, - arg_idx++, dispensable); + ins_cast_str += + paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name), + out_name, arg_idx++, dispensable); } else if (use_inplace_strategy && inplace_map.count(out_name)) { PADDLE_ENFORCE_NE( inplace_map[out_name], "", @@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody( // Leaf Var that doesn't stop gradient can't use inplace strategy. // Increase inplace_version. inplace_strategy_str += paddle::string::Sprintf( - INPLACE_STRATEGY_TEMPLATE, inplace_input_name, inplace_input_name, - INPLACE_LEAF_ERROR_MESSAGE, inplace_input_name, inplace_input_name, - inplace_input_name); - outs_initializer += - paddle::string::Sprintf(out_template, out_name, inplace_input_name); + INPLACE_STRATEGY_TEMPLATE, LegalizeVarName(inplace_input_name), + LegalizeVarName(inplace_input_name), INPLACE_LEAF_ERROR_MESSAGE, + LegalizeVarName(inplace_input_name), + LegalizeVarName(inplace_input_name), + LegalizeVarName(inplace_input_name)); + outs_initializer += paddle::string::Sprintf( + out_template, out_name, LegalizeVarName(inplace_input_name)); outs_initializer += ","; } else { // There are few Operators that have duplicable output, like `Out` in @@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody( if (input_args != "") { input_args += ","; } - auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); + auto out_num_str = + paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name)); input_args += ARG_OUT_NUM_TYPE; input_args += out_num_str; input_args_num++;