# eager_gen.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import yaml
from prim_base import EagerPrimAPI


def header_include():
    """Return the ``#include`` preamble for the generated prim API header.

    The text begins with a newline and every include directive is followed
    by a newline, so it can be written directly after ``#pragma once``.
    """
    included_headers = (
        "paddle/phi/common/int_array.h",
        "paddle/phi/common/data_type.h",
        "paddle/phi/common/scalar.h",
        "paddle/phi/common/place.h",
        "paddle/utils/optional.h",
    )
    # Empty first/last elements produce the leading and trailing newline.
    directives = ['#include "{}"'.format(path) for path in included_headers]
    return "\n".join([""] + directives + [""])


def eager_source_include():
    """Return the ``#include`` preamble for the generated eager prim source.

    The generated ``.cc`` file includes the eager API, the generated dygraph
    forward functions, and the generated prim API header (whose declarations
    it implements). The text begins with a newline and ends with a newline
    after the last include.
    """
    return """
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
"""


def api_namespace():
    """Return the C++ namespace scaffolding used by the generated files.

    Returns:
        A 3-tuple of strings:
        ``(opening, aliases, closing)`` where ``opening`` opens
        ``paddle::prim``, ``aliases`` holds ``using`` declarations mapping
        short names onto ``paddle::experimental`` types, and ``closing``
        closes both namespaces.
    """
    opening = "\nnamespace paddle {\nnamespace prim {\n"
    # Each alias line is preceded by a newline; one extra newline terminates
    # the final declaration.
    aliases = (
        "".join(
            "\nusing {0} = paddle::experimental::{0};".format(type_name)
            for type_name in ("Tensor", "Scalar", "IntArray", "DataType")
        )
        + "\n"
    )
    closing = "\n}  // namespace prim\n}  // namespace paddle\n"
    return opening, aliases, closing


def generate_api(
    api_yaml_path, header_file_path, eager_prim_source_file_path, api_prim_path
):
    """Generate the prim API header and the eager prim API source file.

    Args:
        api_yaml_path: Iterable of paths to API-definition yaml files. Each
            file is parsed and, when non-empty, its entries are concatenated
            into a single list of API descriptions.
        header_file_path: Output path for the generated header (declarations).
        eager_prim_source_file_path: Output path for the generated eager
            source file (implementations).
        api_prim_path: Path to a yaml file whose parsed content is passed to
            ``EagerPrimAPI`` together with every API description.

    Only APIs for which ``EagerPrimAPI(...).is_prim_api`` is truthy are
    written to the outputs.
    """
    # Read every input before touching the outputs, so a missing or malformed
    # yaml file does not leave freshly truncated generated files behind.
    apis = []
    for each_api_yaml in api_yaml_path:
        with open(each_api_yaml, 'r', encoding='utf-8') as f:
            # NOTE(review): FullLoader is kept from the original; these yaml
            # files are presumably repository-controlled, not untrusted input.
            api_list = yaml.load(f, Loader=yaml.FullLoader)
            if api_list:
                apis.extend(api_list)

    with open(api_prim_path, 'rt', encoding='utf-8') as f:
        api_prims = yaml.safe_load(f)

    namespace_begin, using_aliases, namespace_end = api_namespace()

    # Context managers guarantee both generated files are flushed and closed
    # even if code generation for some API raises.
    with open(header_file_path, 'w', encoding='utf-8') as header_file, open(
        eager_prim_source_file_path, 'w', encoding='utf-8'
    ) as eager_prim_source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace_begin)
        header_file.write(using_aliases)

        # The source includes the generated header, so it only needs to
        # reopen the namespace; the aliases come from the header.
        eager_prim_source_file.write(eager_source_include())
        eager_prim_source_file.write(namespace_begin)

        for api in apis:
            prim_api = EagerPrimAPI(api, api_prims)
            if prim_api.is_prim_api:
                header_file.write(prim_api.gene_prim_api_declaration())
                eager_prim_source_file.write(
                    prim_api.gene_eager_prim_api_code()
                )

        header_file.write(namespace_end)
        eager_prim_source_file.write(namespace_end)


def main():
    """Parse command-line options and run ``generate_api`` with them."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ API files'
    )
    parser.add_argument(
        '--api_yaml_path',
        help='path to api yaml file',
        nargs='+',
        default=['paddle/phi/api/yaml/ops.yaml'],
    )

    parser.add_argument(
        '--prim_api_header_path',
        help='output of generated prim_api header code file',
        default='paddle/fluid/prim/api/generated_prim/prim_generated_api.h',
    )

    parser.add_argument(
        '--eager_prim_api_source_path',
        help='output of generated eager_prim_api source code file',
        default='paddle/fluid/prim/api/generated_prim/eager_prim_api.cc',
    )

    parser.add_argument(
        '--api_prim_yaml_path',
        help='Primitive API list yaml file.',
        default='paddle/fluid/prim/api/auto_code_generated/api.yaml',
    )

    options = parser.parse_args()

    generate_api(
        options.api_yaml_path,
        options.prim_api_header_path,
        options.eager_prim_api_source_path,
        options.api_prim_yaml_path,
    )


if __name__ == '__main__':
    main()