未验证 提交 d2a911b4 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[Yaml]Support parsing fwd & bwd returns with name (#40107)

上级 73a4fe6c
...@@ -208,39 +208,26 @@ def ParseYamlArgs(string): ...@@ -208,39 +208,26 @@ def ParseYamlArgs(string):
def ParseYamlReturns(string): def ParseYamlReturns(string):
# Example: Tensor, Tensor # Example0: Tensor(out), Tensor(out1)
# Example1: Tensor, Tensor
# list = [ ["", ret_type, orig_position], ...] # Example2: Tensor[](out), Tensor
returns_list = []
returns = [x.strip() for x in string.strip().split(",")]
for i in range(len(returns)):
ret_type = returns[i]
assert ret_type in yaml_types_mapping.keys()
ret_type = yaml_types_mapping[ret_type]
returns_list.append(["", ret_type, i])
return returns_list
def ParseYamlReturnsWithName(string):
# Example: Tensor(out), Tensor(out1)
# list = [ [ret_name, ret_type, orig_position], ...] # list = [ [ret_name, ret_type, orig_position], ...]
returns_list = [] returns_list = []
returns = [x.strip() for x in string.strip().split(",")] returns = [x.strip() for x in string.strip().split(",")]
atype = r'(.*?)'
aname = r'(.*?)'
pattern = f'{atype}\({aname}\)'
for i in range(len(returns)): for i in range(len(returns)):
ret = returns[i] ret = returns[i]
m = re.search(pattern, ret)
ret_type = m.group(1) ret_name = ""
ret_name = m.group(2) if "(" in ret and ")" in ret:
# Remove trailing ')'
ret = ret[:-1]
ret_type = ret.split("(")[0].strip()
ret_name = ret.split("(")[1].strip()
else:
ret_type = ret.strip()
assert ret_type in yaml_types_mapping.keys() assert ret_type in yaml_types_mapping.keys()
ret_type = yaml_types_mapping[ret_type] ret_type = yaml_types_mapping[ret_type]
...@@ -266,7 +253,7 @@ def ParseYamlForwardFromBackward(string): ...@@ -266,7 +253,7 @@ def ParseYamlForwardFromBackward(string):
function_returns = m.group(3) function_returns = m.group(3)
forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args) forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args)
forward_returns_list = ParseYamlReturnsWithName(function_returns) forward_returns_list = ParseYamlReturns(function_returns)
return forward_inputs_list, forward_attrs_list, forward_returns_list return forward_inputs_list, forward_attrs_list, forward_returns_list
...@@ -296,7 +283,7 @@ def ParseYamlBackward(args_str, returns_str): ...@@ -296,7 +283,7 @@ def ParseYamlBackward(args_str, returns_str):
args_str = re.search(args_pattern, args_str).group(1) args_str = re.search(args_pattern, args_str).group(1)
inputs_list, attrs_list = ParseYamlArgs(args_str) inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturnsWithName(returns_str) returns_list = ParseYamlReturns(returns_str)
return inputs_list, attrs_list, returns_list return inputs_list, attrs_list, returns_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册