generate_phi_kernel_dialect.py 12.5 KB
Newer Older
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9 10 11 12 13 14 15
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
16
import yaml
17 18 19 20 21
import os
from get_compat_kernel_signature import get_compat_kernels_info

# TODO @DannyIsFunny: more attr types need to be supported.
# Maps the C++ attribute type names reported for a phi kernel to the
# TableGen attribute types used in the generated kernel dialect.
attr_type_converter = {
    "int": 'SI32Attr',
    "bool": 'BoolAttr',
    "int64_t": 'SI64Attr',
    "float": 'F32Attr',
    "string": 'StrAttr',
    "vector<int>": 'I32ArrayAttr',
}

# The three tables below translate the target / layout / precision fields
# of a kernel placement string (as found in kernels.json) to the spellings
# used in the generated dialect names and DenseTensor types.
target_type_converter = {"CPU": "CPU", "GPU": "GPU", "Undefined": "UNK"}
layout_type_converter = {
    "NCHW": "NCHW",
    "NHWC": "NHWC",
    "Undefined(AnyLayout)": "ANY",
}
precision_type_converter = {
    "uint8": "UINT8",
    "int8": "INT8",
    "int16": "INT16",
    "int32": "INT32",
    "int64": "INT64",
    "float16": "FLOAT16",
    "bfloat16": "BFLOAT16",
    "float32": "FLOAT32",
    "float64": "FLOAT64",
    "complex64": "COMPLEX64",
    "complex128": "COMPLEX128",
    "bool": "BOOL",
    "Undefined": "UNK",
}

# Input files, resolved relative to the current working directory.
kernel_types_info_file = "./kernels.json"
kernel_signature_info_file = "./kernel_signature.json"

skipped_phi_api_list_file = "./skipped_phi_api.json"


def get_skipped_kernel_list():
    """Return the kernel names that must be left out of the generated dialect.

    ``skipped_phi_api_list_file`` is a JSON object whose ``"phi_apis"`` entry
    lists API (op) names and whose ``"phi_kernels"`` entry lists raw kernel
    names. Every API from the phi API YAML files that declares both
    ``kernel`` and ``infer_meta`` and is listed in ``"phi_apis"`` contributes
    its ``kernel.func`` value. The ``"phi_kernels"`` entries are appended
    after those.

    The YAML files are located through the relative path ``"../../"``, so the
    result depends on the current working directory.
    """
    with open(skipped_phi_api_list_file, 'r') as f:
        skip_config = json.load(f)

    kernels_to_skip = [
        api["kernel"]["func"]
        for api in get_api_yaml_info("../../")
        if "kernel" in api and "infer_meta" in api
        and api["op"] in skip_config["phi_apis"]
    ]
    kernels_to_skip += skip_config["phi_kernels"]
    return kernels_to_skip


def get_api_yaml_info(file_path):
    """Load the phi API definitions from ``api.yaml`` and ``legacy_api.yaml``.

    Args:
        file_path: path prefix of the Paddle source tree; the YAML files are
            read from ``<file_path>/paddle/phi/api/yaml/``.

    Returns:
        A single list with the entries of ``api.yaml`` followed by those of
        ``legacy_api.yaml``. A file that loads as an empty/falsy document
        contributes nothing.
    """
    yaml_files = (
        file_path + "/paddle/phi/api/yaml/api.yaml",
        file_path + "/paddle/phi/api/yaml/legacy_api.yaml",
    )
    apis = []
    for yaml_file in yaml_files:
        with open(yaml_file, 'r') as f:
            entries = yaml.load(f, Loader=yaml.FullLoader)
        if entries:
            apis.extend(entries)
    return apis

def generate_kernel_name(op_name, place_str):
    """Build the dialect alias and TableGen class name of one kernel placement.

    Args:
        op_name: phi kernel name, e.g. ``"add_n"``.
        place_str: bracketed ``"target, layout, precision"`` triple; the first
            and last characters are discarded before splitting on commas.

    Returns:
        ``(alias, class_name)``.
        - ``alias`` is ``"<op>.<TARGET>.<PRECISION>.<LAYOUT>"``.
        - ``class_name`` concatenates the title-cased op name (underscores
          removed), the title-cased target, the precision and the
          title-cased layout.
        Field values are mapped through the converter tables first.
    """
    raw_target, raw_layout, raw_precision = place_str[1:-1].split(',')
    target = target_type_converter[raw_target.strip()]
    layout = layout_type_converter[raw_layout.strip()]
    precision = precision_type_converter[raw_precision.strip()]

    class_name = (op_name.replace("_", "").title() + target.title() +
                  precision + layout.title())
    alias = "{}.{}.{}.{}".format(op_name, target, precision, layout)
    return alias, class_name


def generate_attrs_info(op_name, attrs_info):
    """Return the TableGen attribute operands of a kernel.

    Attribute names come from ``kernel_signature_info_file`` merged with
    ``get_compat_kernels_info()``, looked up under ``op_name``'s ``"attrs"``.
    Attribute types come from ``attrs_info``, mapped through
    ``attr_type_converter``.

    Returns:
        Comma-separated ``"<AttrType>:$<name>"`` pairs, or ``""`` when the
        number of signature attribute names differs from ``len(attrs_info)``.
    """
    # NOTE(review): the signature file and compat info are reloaded on every
    # call, i.e. once per generated kernel.
    with open(kernel_signature_info_file) as f:
        signatures = json.load(f)
        signatures.update(get_compat_kernels_info())

    attr_names = signatures[op_name]["attrs"]
    if len(attr_names) != len(attrs_info):
        # Names and types cannot be paired up; emit no attributes.
        return ""
    return ",".join(
        '{type_}:${name_}'.format(type_=attr_type_converter[attr_type],
                                  name_=attr_name)
        for attr_name, attr_type in zip(attr_names, attrs_info))


def generate_inputs_info(input_info):
    """Return the TableGen operand list for a kernel's inputs.

    Each entry of ``input_info`` is a ``"target, layout, precision"`` string.
    Entry ``i`` becomes `` DenseTensor<"TARGET","PRECISION","LAYOUT">:$in<i>``
    (with a leading space), using the converter-table spellings.
    Entries are joined with ``","``; an empty ``input_info`` gives ``""``.
    """
    operands = []
    for position, place in enumerate(input_info):
        raw_target, raw_layout, raw_precision = place.split(',')
        # todo: check validity
        target = target_type_converter[raw_target.strip()]
        layout = layout_type_converter[raw_layout.strip()]
        precision = precision_type_converter[raw_precision.strip()]
        operands.append(" DenseTensor<\"{}\",\"{}\",\"{}\">:$in{}".format(
            target, precision, layout, position))
    return ",".join(operands)


def generate_arguments_info(op_name, input_info, attr_info):
    """Return the ``let arguments = (ins ...);`` line of one kernel.

    The operands are the device context, then the inputs produced by
    generate_inputs_info(), then the attributes produced by
    generate_attrs_info(). Empty pieces, e.g. for a kernel without inputs or
    attributes, are dropped.
    """
    inputs_part = generate_inputs_info(input_info)
    attrs_part = generate_attrs_info(op_name, attr_info)
    pieces = (["Context:$dev_ctx"] + inputs_part.split(",") +
              attrs_part.split(","))
    joined = ",".join(piece for piece in pieces if piece)
    return "let arguments = (ins {});".format(joined)


def generate_results_info(output_info):
    """Return the ``let results = (outs ...);`` line of one kernel.

    Each entry of ``output_info`` is a ``"target, layout, precision"`` string
    and becomes `` DenseTensor<"TARGET","PRECISION","LAYOUT">:$out<i>``,
    using the converter-table spellings. With no outputs the result is
    ``"let results = (outs);"``.
    """
    operands = []
    for position, place in enumerate(output_info):
        raw_target, raw_layout, raw_precision = place.split(',')
        # todo: check validity
        target = target_type_converter[raw_target.strip()]
        layout = layout_type_converter[raw_layout.strip()]
        precision = precision_type_converter[raw_precision.strip()]
        operands.append(" DenseTensor<\"{}\",\"{}\",\"{}\">:$out{},".format(
            target, precision, layout, position))
    # Dropping the last character removes the final entry's trailing comma,
    # or the space after "outs" when there are no outputs.
    text = "let results = (outs " + "".join(operands)
    return "{});".format(text[:-1])


def generate_supported_kernel_list(load_dict):
    """Return the op names whose kernels can be emitted to the dialect.

    An op is kept when both of the following hold:
    - at least one of its kernels uses only attribute types listed in
      ``attr_type_converter``;
    - the op has an entry in the kernel signature info
      (``kernel_signature_info_file`` merged with ``get_compat_kernels_info()``).
    Ops named by get_skipped_kernel_list() are then removed.

    Args:
        load_dict: parsed kernels.json, mapping op name to a list of
            ``{kernel_alias: kernel_info}`` dicts, where each kernel_info has
            an ``"attribute"`` list.

    Returns:
        De-duplicated op names in first-seen order.
    """
    with open(kernel_signature_info_file) as f:
        kernel_attrs_names = json.load(f)
        kernel_attrs_names.update(get_compat_kernels_info())

    supported = []
    for op_name, kernel_list in load_dict.items():
        for kernel_info in kernel_list:
            for kernel_alias_ in kernel_info:
                attributes = kernel_info[kernel_alias_]["attribute"]
                if op_name in kernel_attrs_names and all(
                        attribute in attr_type_converter
                        for attribute in attributes):
                    supported.append(op_name)
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    supported = list(dict.fromkeys(supported))

    # Filter instead of list.remove(): the skip list may name kernels that
    # were never supported (or name one twice), which made remove() raise
    # ValueError.
    skipped_kernel_list = get_skipped_kernel_list()
    return [name for name in supported if name not in skipped_kernel_list]


def scan_kernel_info(load_dict):
    """Print the distinct targets, layouts and precisions used in load_dict.

    Every kernel alias (a bracketed ``"target, layout, precision"`` string) of
    every op is split into its three stripped fields. The distinct values of
    each field are printed as a list, in the order targets, layouts,
    precisions. This is a debugging aid; it returns None.
    """
    seen_targets, seen_layouts, seen_precisions = set(), set(), set()
    for kernel_list in load_dict.values():
        for kernel_info in kernel_list:
            for alias in kernel_info:
                target, layout, precision = alias[1:-1].split(',')
                seen_targets.add(target.strip())
                seen_layouts.add(layout.strip())
                seen_precisions.add(precision.strip())
    for seen in (seen_targets, seen_layouts, seen_precisions):
        print(list(seen))


def _generate_kernel_dialect(op_name, kernel_alias_, kernel_info,
                             kernel_class):
    """Return the TableGen ``def`` of one kernel, derived from ``kernel_class``.

    Args:
        op_name: phi kernel name.
        kernel_alias_: bracketed ``"target, layout, precision"`` placement.
        kernel_info: dict with ``"input"``, ``"attribute"`` and ``"output"``.
        kernel_class: TableGen base class, e.g. ``"PDTCPU_Kernel"``.
    """
    alias, class_name = generate_kernel_name(op_name, kernel_alias_)
    summary = 'let summary = "{name}";'.format(name=alias)
    # alias is "<op>.<target>.<precision>.<layout>"; the registered dialect
    # op name leaves out the target and is lower-cased.
    alias_parts = alias.split(".")
    dialect_name = ".".join([alias_parts[0], alias_parts[2], alias_parts[3]])
    header = 'def {kernel_name} : {kernel_class}<"{name}",[NoSideEffect]> {{'.format(
        kernel_name=class_name,
        kernel_class=kernel_class,
        name=dialect_name.lower())

    arguments = generate_arguments_info(op_name, kernel_info["input"],
                                        kernel_info["attribute"])
    results = generate_results_info(kernel_info["output"])

    return '{header_}\n  {summary_}\n  {arguments_}\n  {results_}\n}}\n'.format(
        header_=header,
        summary_=summary,
        arguments_=arguments,
        results_=results)


def generate_cpu_kernel_dialect(op_name, kernel_alias_, kernel_info):
    """Return the TableGen definition of a CPU kernel (``PDTCPU_Kernel``)."""
    return _generate_kernel_dialect(op_name, kernel_alias_, kernel_info,
                                    "PDTCPU_Kernel")


def generate_gpu_kernel_dialect(op_name, kernel_alias_, kernel_info):
    """Return the TableGen definition of a GPU kernel (``PDTGPU_Kernel``)."""
    return _generate_kernel_dialect(op_name, kernel_alias_, kernel_info,
                                    "PDTGPU_Kernel")


def generate_dialect_head():
    """Return the banner comment and include preamble of the generated .td files.

    The text opens the ``PTEN_KERNELS`` include guard and ends without a
    trailing newline. The caller appends the kernel definitions and the
    matching ``#endif``.
    """
    banner_width = 80
    banner_text = [
        "",
        "Kernel Definitions",
        "",
        "Automatically generated file, do not edit!",
        "Generated by tools/infrt/generate_phi_kernel_dialect.py",
        "",
    ]
    # Build the 80-column LLVM-style banner so every line stays aligned.
    top = "/*===- TableGen'source file ".ljust(banner_width - 5, "-") + "===*\\"
    bottom = "\\*===".ljust(banner_width - 5, "-") + "===*/"
    framed = [
        "|* " + text.ljust(banner_width - 6) + " *|" for text in banner_text
    ]
    comment_ = "\n".join([top] + framed + [bottom]) + "\n"

    includes_ = "\n".join([
        "#ifndef PTEN_KERNELS",
        "#define PTEN_KERNELS",
        'include "mlir/Interfaces/InferTypeOpInterface.td"',
        'include "mlir/Interfaces/LoopLikeInterface.td"',
        'include "mlir/IR/OpBase.td"',
        'include "paddle/infrt/dialect/phi/ir/infrt_phi_kernel.td"',
    ])

    return comment_ + includes_


def get_kernel_target(kernel_alias_):
    """Return the target field of a bracketed kernel alias.

    The first and last characters of ``kernel_alias_`` are dropped. The text
    before the first comma is returned as-is (not stripped), so
    ``"(CPU, NCHW, float32)"`` yields ``"CPU"``.
    """
    inner = kernel_alias_[1:-1]
    target, _, _ = inner.partition(",")
    return target


def main():
    """Generate the phi CPU/GPU kernel dialect TableGen files.

    The steps are:
    - Read ``kernel_types_info_file`` (op name -> list of
      ``{kernel alias: kernel info}``) and keep the ops returned by
      generate_supported_kernel_list().
    - Emit one definition per kernel into
      ``../../paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td`` or
      ``phi_gpu_kernels.td``, chosen by the alias' target.
    - Report kernels for any other target and skip them.
    """
    # Only the JSON parse needs the input file; close it before generating.
    with open(kernel_types_info_file, "r") as f:
        load_dict = json.load(f)

    head = generate_dialect_head()

    supported_kernels = generate_supported_kernel_list(load_dict)
    print("Supported kernels:")
    print(supported_kernels)
    # Set for constant-time membership tests in the loop below.
    supported = set(supported_kernels)

    generators = {
        "CPU": generate_cpu_kernel_dialect,
        "GPU": generate_gpu_kernel_dialect,
    }
    registries = {target: [] for target in generators}
    for op_name, kernel_list in load_dict.items():
        if op_name not in supported:
            continue
        for kernel_info in kernel_list:
            for kernel_alias_, info in kernel_info.items():
                target = get_kernel_target(kernel_alias_)
                generator = generators.get(target)
                if generator is None:
                    print("Unsupported backend:" + target)
                    continue
                registries[target].append(
                    generator(op_name, kernel_alias_, info))

    end = "#endif  // PTEN_KERNELS"
    outputs = (
        ("CPU", "../../paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td"),
        ("GPU", "../../paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td"),
    )
    for target, path in outputs:
        with open(path, "w") as dst:
            dst.write('{start_}\n{dialect_}\n{end_}'.format(
                start_=head, dialect_="".join(registries[target]), end_=end))


if __name__ == '__main__':
    # Both JSON inputs are produced by earlier build steps; report each
    # missing one and only run the generator when all are present.
    missing_files = [
        path for path in (kernel_types_info_file, kernel_signature_info_file)
        if not os.path.exists(path)
    ]
    for path in missing_files:
        print("Error: '{file_name}' not exist!".format(file_name=path))
    if not missing_files:
        main()