未验证 提交 d6e99fe4 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Adjust Yaml name parsing to satisfy Sparse-related APIs (#40480)

上级 481db5e9
...@@ -148,6 +148,12 @@ def ReadBwdFile(filepath): ...@@ -148,6 +148,12 @@ def ReadBwdFile(filepath):
###################### ######################
### Yaml Parsers ### ### Yaml Parsers ###
###################### ######################
def RemoveSpecialSymbolsInName(string):
# Remove any name after '@'
ret = string.split("@")[0]
return ret
def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
# intermediate_outputs : [name0, name1, ...] # intermediate_outputs : [name0, name1, ...]
# forward_returns_list : [[ret_name, type, orig_pos], ...] # forward_returns_list : [[ret_name, type, orig_pos], ...]
...@@ -166,15 +172,19 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): ...@@ -166,15 +172,19 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
def ParseDispensable(string): def ParseDispensable(string):
# string: "X, Y" # string: "X, Y"
string = RemoveSpecialSymbolsInName(string)
return [v.strip() for v in string.split(",")] return [v.strip() for v in string.split(",")]
def ParseIntermediate(string): def ParseIntermediate(string):
string = RemoveSpecialSymbolsInName(string)
return [v.strip() for v in string.split(",")] return [v.strip() for v in string.split(",")]
def ParseNoNeedBuffer(string): def ParseNoNeedBuffer(string):
# string: "x, y" # string: "x, y"
string = RemoveSpecialSymbolsInName(string)
no_need_buffer_set = set() no_need_buffer_set = set()
for name in string.split(","): for name in string.split(","):
no_need_buffer_set.add(name.strip()) no_need_buffer_set.add(name.strip())
...@@ -204,6 +214,8 @@ def ParseYamlArgs(string): ...@@ -204,6 +214,8 @@ def ParseYamlArgs(string):
assert arg_type in yaml_types_mapping.keys() assert arg_type in yaml_types_mapping.keys()
arg_type = yaml_types_mapping[arg_type] arg_type = yaml_types_mapping[arg_type]
arg_name = RemoveSpecialSymbolsInName(arg_name)
if "Tensor" in arg_type: if "Tensor" in arg_type:
assert default_value is None assert default_value is None
inputs_list.append([arg_name, arg_type, i]) inputs_list.append([arg_name, arg_type, i])
...@@ -239,6 +251,7 @@ def ParseYamlReturns(string): ...@@ -239,6 +251,7 @@ def ParseYamlReturns(string):
ret_type = yaml_types_mapping[ret_type] ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type assert "Tensor" in ret_type
ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i]) returns_list.append([ret_name, ret_type, i])
return returns_list return returns_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册