未验证 提交 852a872f 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Added attr & tensor type mapping for final state codegen (#39997)

上级 72e462cd
...@@ -24,6 +24,17 @@ core_ops_args_info = {} ...@@ -24,6 +24,17 @@ core_ops_args_info = {}
core_ops_args_type_info = {} core_ops_args_type_info = {}
yaml_types_mapping = {
'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \
'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>',
'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>'
}
def ParseArguments(): def ParseArguments():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Eager Code Generator Args Parser') description='Eager Code Generator Args Parser')
...@@ -59,7 +70,9 @@ def IsPlainTensorType(string): ...@@ -59,7 +70,9 @@ def IsPlainTensorType(string):
def IsVectorTensorType(string): def IsVectorTensorType(string):
vector_tensor_types = ['list(Tensor)'] vector_tensor_types = [
'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>'
]
if string in vector_tensor_types: if string in vector_tensor_types:
return True return True
return False return False
...@@ -180,6 +193,9 @@ def ParseYamlArgs(string): ...@@ -180,6 +193,9 @@ def ParseYamlArgs(string):
arg_name = m.group(3).split("=")[0].strip() arg_name = m.group(3).split("=")[0].strip()
default_value = m.group(3).split("=")[1].strip() if len( default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None m.group(3).split("=")) > 1 else None
assert arg_type in yaml_types_mapping.keys()
arg_type = yaml_types_mapping[arg_type]
if "Tensor" in arg_type: if "Tensor" in arg_type:
assert default_value is None assert default_value is None
inputs_list.append([arg_name, arg_type, i]) inputs_list.append([arg_name, arg_type, i])
...@@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string): ...@@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string):
m = re.search(pattern, ret) m = re.search(pattern, ret)
ret_type = m.group(1) ret_type = m.group(1)
ret_name = m.group(2) ret_name = m.group(2)
assert ret_type in yaml_types_mapping.keys()
ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type assert "Tensor" in ret_type
returns_list.append([ret_name, ret_type, i]) returns_list.append([ret_name, ret_type, i])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册