From 852a872f6dafb3f8f32b30567d8402651f8e9e1e Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Tue, 1 Mar 2022 21:00:59 +0800 Subject: [PATCH] Added attr & tensor type mapping for final state codegen (#39997) --- .../final_state_generator/eager_gen.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index c6e56e34627..02183e2ca5c 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -24,6 +24,17 @@ core_ops_args_info = {} core_ops_args_type_info = {} +yaml_types_mapping = { + 'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \ + 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ + 'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \ + 'int64_t[]' : 'std::vector', 'int[]' : 'std::vector', + 'Tensor' : 'Tensor', + 'Tensor[]' : 'std::vector', + 'Tensor[Tensor[]]' : 'std::vector>' +} + + def ParseArguments(): parser = argparse.ArgumentParser( description='Eager Code Generator Args Parser') @@ -59,7 +70,9 @@ def IsPlainTensorType(string): def IsVectorTensorType(string): - vector_tensor_types = ['list(Tensor)'] + vector_tensor_types = [ + 'std::vector>', 'std::vector' + ] if string in vector_tensor_types: return True return False @@ -180,6 +193,9 @@ def ParseYamlArgs(string): arg_name = m.group(3).split("=")[0].strip() default_value = m.group(3).split("=")[1].strip() if len( m.group(3).split("=")) > 1 else None + + assert arg_type in yaml_types_mapping.keys() + arg_type = yaml_types_mapping[arg_type] if "Tensor" in arg_type: assert default_value is None inputs_list.append([arg_name, arg_type, i]) @@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string): m = re.search(pattern, ret) ret_type = m.group(1) ret_name = m.group(2) + + assert ret_type in yaml_types_mapping.keys() + ret_type = yaml_types_mapping[ret_type] + assert "Tensor" in ret_type returns_list.append([ret_name, ret_type, i]) -- GitLab