diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index c6e56e34627a52bc19df7e8d87371811fcec8697..02183e2ca5ce9f0996017eb7df59ee716b0f1ae2 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -24,6 +24,17 @@ core_ops_args_info = {} core_ops_args_type_info = {} +yaml_types_mapping = { + 'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \ + 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ + 'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \ + 'int64_t[]' : 'std::vector', 'int[]' : 'std::vector', + 'Tensor' : 'Tensor', + 'Tensor[]' : 'std::vector', + 'Tensor[Tensor[]]' : 'std::vector>' +} + + def ParseArguments(): parser = argparse.ArgumentParser( description='Eager Code Generator Args Parser') @@ -59,7 +70,9 @@ def IsPlainTensorType(string): def IsVectorTensorType(string): - vector_tensor_types = ['list(Tensor)'] + vector_tensor_types = [ + 'std::vector>', 'std::vector' + ] if string in vector_tensor_types: return True return False @@ -180,6 +193,9 @@ def ParseYamlArgs(string): arg_name = m.group(3).split("=")[0].strip() default_value = m.group(3).split("=")[1].strip() if len( m.group(3).split("=")) > 1 else None + + assert arg_type in yaml_types_mapping.keys() + arg_type = yaml_types_mapping[arg_type] if "Tensor" in arg_type: assert default_value is None inputs_list.append([arg_name, arg_type, i]) @@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string): m = re.search(pattern, ret) ret_type = m.group(1) ret_name = m.group(2) + + assert ret_type in yaml_types_mapping.keys() + ret_type = yaml_types_mapping[ret_type] + assert "Tensor" in ret_type returns_list.append([ret_name, ret_type, i])