From d6e99fe4eeb24efd445cfd1093c35dc43e6e0e15 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Mon, 14 Mar 2022 13:47:13 +0800 Subject: [PATCH] Adjust Yaml name parsing to satisfy Sparse-related APIs (#40480) --- .../final_state_generator/eager_gen.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 967891fe52..537c2bb7f0 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -148,6 +148,12 @@ def ReadBwdFile(filepath): ###################### ### Yaml Parsers ### ###################### +def RemoveSpecialSymbolsInName(string): + # Remove any name after '@' + ret = string.split("@")[0] + return ret + + def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): # intermediate_outputs : [name0, name1, ...] # forward_returns_list : [[ret_name, type, orig_pos], ...] @@ -166,15 +172,19 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): def ParseDispensable(string): # string: "X, Y" + string = RemoveSpecialSymbolsInName(string) return [v.strip() for v in string.split(",")] def ParseIntermediate(string): + string = RemoveSpecialSymbolsInName(string) return [v.strip() for v in string.split(",")] def ParseNoNeedBuffer(string): # string: "x, y" + string = RemoveSpecialSymbolsInName(string) + no_need_buffer_set = set() for name in string.split(","): no_need_buffer_set.add(name.strip()) @@ -204,6 +214,8 @@ def ParseYamlArgs(string): assert arg_type in yaml_types_mapping.keys() arg_type = yaml_types_mapping[arg_type] + + arg_name = RemoveSpecialSymbolsInName(arg_name) if "Tensor" in arg_type: assert default_value is None inputs_list.append([arg_name, arg_type, i]) @@ -239,6 +251,7 @@ def ParseYamlReturns(string): ret_type = yaml_types_mapping[ret_type] assert "Tensor" in ret_type + ret_name = RemoveSpecialSymbolsInName(ret_name) returns_list.append([ret_name, ret_type, i]) return returns_list -- GitLab