From d2a911b46be80a01ef685f0bbc2bdffb683316b1 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 4 Mar 2022 10:13:08 +0800 Subject: [PATCH] [Yaml]Support parsing fwd & bwd returns with name (#40107) --- .../final_state_generator/eager_gen.py | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 65dbb0368c6..4945a6fb654 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -208,39 +208,26 @@ def ParseYamlArgs(string): def ParseYamlReturns(string): - # Example: Tensor, Tensor - - # list = [ ["", ret_type, orig_position], ...] - returns_list = [] - - returns = [x.strip() for x in string.strip().split(",")] - for i in range(len(returns)): - ret_type = returns[i] - - assert ret_type in yaml_types_mapping.keys() - ret_type = yaml_types_mapping[ret_type] - - returns_list.append(["", ret_type, i]) - - return returns_list - - -def ParseYamlReturnsWithName(string): - # Example: Tensor(out), Tensor(out1) + # Example0: Tensor(out), Tensor(out1) + # Example1: Tensor, Tensor + # Example2: Tensor[](out), Tensor # list = [ [ret_name, ret_type, orig_position], ...] returns_list = [] returns = [x.strip() for x in string.strip().split(",")] - atype = r'(.*?)' - aname = r'(.*?)' - pattern = f'{atype}\({aname}\)' for i in range(len(returns)): ret = returns[i] - m = re.search(pattern, ret) - ret_type = m.group(1) - ret_name = m.group(2) + + ret_name = "" + if "(" in ret and ")" in ret: + # Remove trailing ')' + ret = ret[:-1] + ret_type = ret.split("(")[0].strip() + ret_name = ret.split("(")[1].strip() + else: + ret_type = ret.strip() assert ret_type in yaml_types_mapping.keys() ret_type = yaml_types_mapping[ret_type] @@ -266,7 +253,7 @@ def ParseYamlForwardFromBackward(string): function_returns = m.group(3) forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args) - forward_returns_list = ParseYamlReturnsWithName(function_returns) + forward_returns_list = ParseYamlReturns(function_returns) return forward_inputs_list, forward_attrs_list, forward_returns_list @@ -296,7 +283,7 @@ def ParseYamlBackward(args_str, returns_str): args_str = re.search(args_pattern, args_str).group(1) inputs_list, attrs_list = ParseYamlArgs(args_str) - returns_list = ParseYamlReturnsWithName(returns_str) + returns_list = ParseYamlReturns(returns_str) return inputs_list, attrs_list, returns_list -- GitLab