未验证 提交 d6e99fe4 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Adjust Yaml name parsing to satisfy Sparse-related APIs (#40480)

上级 481db5e9
......@@ -148,6 +148,12 @@ def ReadBwdFile(filepath):
######################
### Yaml Parsers ###
######################
def RemoveSpecialSymbolsInName(string):
# Remove any name after '@'
ret = string.split("@")[0]
return ret
def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
# intermediate_outputs : [name0, name1, ...]
# forward_returns_list : [[ret_name, type, orig_pos], ...]
......@@ -166,15 +172,19 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
def ParseDispensable(string):
# string: "X, Y"
string = RemoveSpecialSymbolsInName(string)
return [v.strip() for v in string.split(",")]
def ParseIntermediate(string):
string = RemoveSpecialSymbolsInName(string)
return [v.strip() for v in string.split(",")]
def ParseNoNeedBuffer(string):
# string: "x, y"
string = RemoveSpecialSymbolsInName(string)
no_need_buffer_set = set()
for name in string.split(","):
no_need_buffer_set.add(name.strip())
......@@ -204,6 +214,8 @@ def ParseYamlArgs(string):
assert arg_type in yaml_types_mapping.keys()
arg_type = yaml_types_mapping[arg_type]
arg_name = RemoveSpecialSymbolsInName(arg_name)
if "Tensor" in arg_type:
assert default_value is None
inputs_list.append([arg_name, arg_type, i])
......@@ -239,6 +251,7 @@ def ParseYamlReturns(string):
ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type
ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i])
return returns_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册