diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 967891fe5227dcd6129c0ef1808fba7720711568..537c2bb7f02be4001ad8adbc3ca97133a677ef81 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -148,6 +148,12 @@ def ReadBwdFile(filepath): ###################### ### Yaml Parsers ### ###################### +def RemoveSpecialSymbolsInName(string): + # Remove any name after '@' + ret = string.split("@")[0] + return ret + + def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): # intermediate_outputs : [name0, name1, ...] # forward_returns_list : [[ret_name, type, orig_pos], ...] @@ -166,15 +172,19 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): def ParseDispensable(string): # string: "X, Y" + string = RemoveSpecialSymbolsInName(string) return [v.strip() for v in string.split(",")] def ParseIntermediate(string): + string = RemoveSpecialSymbolsInName(string) return [v.strip() for v in string.split(",")] def ParseNoNeedBuffer(string): # string: "x, y" + string = RemoveSpecialSymbolsInName(string) + no_need_buffer_set = set() for name in string.split(","): no_need_buffer_set.add(name.strip()) @@ -204,6 +214,8 @@ def ParseYamlArgs(string): assert arg_type in yaml_types_mapping.keys() arg_type = yaml_types_mapping[arg_type] + + arg_name = RemoveSpecialSymbolsInName(arg_name) if "Tensor" in arg_type: assert default_value is None inputs_list.append([arg_name, arg_type, i]) @@ -239,6 +251,7 @@ def ParseYamlReturns(string): ret_type = yaml_types_mapping[ret_type] assert "Tensor" in ret_type + ret_name = RemoveSpecialSymbolsInName(ret_name) returns_list.append([ret_name, ret_type, i]) return returns_list