#!/bin/python

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import functools
import json
from typing import Any, Dict, List

import yaml

# Paths relative to the Paddle source root; the __main__ block prefixes both
# with --paddle_root_path before any of these files is read.
skipped_phi_api_list_file, api_yaml_file = (
    "/tools/infrt/skipped_phi_api.json",
    "/python/paddle/utils/code_gen/api.yaml",
)


@functools.lru_cache(maxsize=None)
def _load_skipped_kernels(skip_list_file, yaml_file):
    """Read the skip-list JSON and the API yaml once per pair of paths.

    Returns a tuple so the cached value cannot be mutated through a caller.
    """
    with open(skip_list_file, 'r') as f:
        skipped_api_list = json.load(f)
    # Set for O(1) membership; presumably "phi_apis" holds api name strings.
    skipped_apis = set(skipped_api_list["phi_apis"])
    skipped_kernels = []
    for api in get_api_yaml_info(yaml_file):
        if "kernel" not in api or "infer_meta" not in api:
            continue
        if api["api"] in skipped_apis:
            skipped_kernels.append(api["kernel"]["func"])
    skipped_kernels += skipped_api_list["phi_kernels"]
    return tuple(skipped_kernels)


def get_skipped_kernel_list():
    """Return the kernel function names that must not be registered.

    The result is the ``kernel.func`` of every api.yaml entry (having both
    ``kernel`` and ``infer_meta``) whose ``api`` name is listed under
    ``phi_apis`` in the skip-list JSON, followed by that JSON's
    ``phi_kernels`` entries.

    Both files are located through the module-level globals
    ``skipped_phi_api_list_file`` and ``api_yaml_file``, read at call time.
    The parsed result is cached per pair of paths. The register code calls
    this once per kernel entry, and re-reading and re-parsing both files each
    time was pure overhead. A fresh list is returned on every call.
    """
    return list(
        _load_skipped_kernels(skipped_phi_api_list_file, api_yaml_file))


def parse_args(argv=None):
    """Parse the command line of this generator.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``paddle_root_path``, ``kernel_info_file``,
        ``infermeta_wrap_file``, ``attr_info_file`` and ``generate_file``.
    """
    # Passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the program name in
    # usage and error messages.
    parser = argparse.ArgumentParser(
        description="gather phi kernel and infermeta info")
    parser.add_argument(
        "--paddle_root_path",
        type=str,
        required=True,
        help="root path of paddle src[WORK_PATH/Paddle].")
    parser.add_argument(
        "--kernel_info_file",
        type=str,
        required=True,
        help="kernel info file generated by get_phi_kernel_function.sh.")
    parser.add_argument(
        "--infermeta_wrap_file",
        type=str,
        required=True,
        help="inferMeta wrap info file.")
    parser.add_argument(
        "--attr_info_file", type=str, required=True, help="attr info file.")
    # The former ``default`` was unreachable because the option is required.
    parser.add_argument(
        "--generate_file",
        type=str,
        required=True,
        help="generated file, e.g. "
        "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc.")
    return parser.parse_args(argv)


def get_api_yaml_info(file_path):
    """Load and return the parsed content of the API yaml at *file_path*.

    The file is opened in a ``with`` block so the handle is always closed.
    FullLoader is kept as before. NOTE(review): this is acceptable only
    because the yaml comes from the Paddle source tree; do not point this at
    untrusted input.
    """
    with open(file_path, "r") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def get_kernel_info(file_path):
    """Read the kernel info file, collapsing consecutive duplicate entries.

    Every non-blank line is kept in stripped form. The exception is a line
    whose first two whitespace-separated fields equal those of the previously
    kept line: it replaces that line instead of being appended after it.

    Returns:
        list of stripped lines, in file order.
    """
    ret = []
    prev = []
    with open(file_path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            info = line.split()
            if not info:
                continue
            # Slices instead of prev[1]/info[1] so one-field lines cannot raise
            # IndexError. The initial empty ``prev`` never matches a non-empty
            # ``info``, so no special first-line branch is needed.
            if info[:2] == prev[:2]:
                ret.pop()
            ret.append(line)
            prev = info
    return ret


def get_infermeta_info(file_path):
    """Return the stripped, non-blank lines of the inferMeta wrap file."""
    with open(file_path, "r") as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]


def get_attr_info(file_path):
    """Parse the attr info file into ``{kernel_ir_name: attr_names or None}``.

    Example line::

        phi_gpu.argsort.float64.any $axisBool$descending

    maps ``phi_gpu.argsort.float64.any`` to ``['axisBool', 'descending']``.
    A line that does not split on ' ' into exactly two fields maps its first
    field to ``None``. Blank lines are ignored.
    """
    ret = {}
    with open(file_path, 'r') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # A blank line used to add a spurious '' -> None entry.
                continue
            datas = stripped.split(' ')
            if len(datas) == 2:
                # The attr field is expected to start with '$' (as in the
                # example), so the first element of the split is dropped.
                ret[datas[0]] = datas[1].split('$')[1:]
            else:
                ret[datas[0]] = None
    return ret


def merge(infer_meta_data, kernel_data, wrap_data):
    """Append the inferMeta function name to every kernel info line.

    The kernel is identified by the first field of its line. Its inferMeta
    name is looked up in this order:
    1. ``wrap_data`` (lines of the form ``<kernel> <wrapper> ...``),
    2. the ``infer_meta.func`` of the api.yaml entry whose ``kernel.func``
       equals the kernel name (entries lacking either key are ignored),
    3. the literal ``"unknown"``.

    Returns:
        list of token lists: the kernel line's fields followed by the name.
    """
    meta_map = {}
    for api in infer_meta_data:
        if "kernel" not in api or "infer_meta" not in api:
            continue
        meta_map[api["kernel"]["func"]] = api["infer_meta"]["func"]

    wrap_map = {}
    for line in wrap_data:
        fields = line.split()
        wrap_map[fields[0]] = fields[1]

    full_kernel_data = []
    for line in kernel_data:
        key = line.split()[0]
        infer_meta = wrap_map.get(key, meta_map.get(key, "unknown"))
        full_kernel_data.append((line + " " + infer_meta).split())
    return full_kernel_data


def gen_warn_info():
    """Return the banner placed at the top of the generated C++ file.

    The banner names this generator script. It previously referred to
    gen_phi_kernel_register.py, and the scraped copy also leaked page residue
    into the emitted text.
    """
    return ("// Generated by tools/infrt/get_phi_kernel_info.py for infrt.\n"
            "// DO NOT edit or include it within paddle.\n")


def gen_include_headers():
    """Return the ``#include`` block of the generated C++ file.

    The text starts with a newline and ends with the newline of the last
    ``#include`` line.
    """
    headers = [
        "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h",
        "paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h",
        "paddle/phi/backends/all_context.h",
        "paddle/phi/include/kernels.h",
        "paddle/phi/include/infermeta.h",
        "paddle/phi/infermeta/generated.h",
    ]
    return "\n" + "".join(f'#include "{header}"\n' for header in headers)


def gen_namespace():
    """Return the (opening, closing) text of the infrt::kernel namespaces.

    Generated code is written between the two strings.
    """
    opening = "\n".join(["", "namespace infrt {", "namespace kernel {", "", ""])
    closing = "\n".join(
        ["", "", "}  // namespace kernel", "}  // namespace infrt", ""])
    return opening, closing


def gen_context(val):
    """Map a kernel backend name to (C++ context type, IR kernel-name prefix).

    Only "CPU" and "GPU" are supported. Any other backend yields ("", "").
    The register code treats that result as "skip this kernel".
    """
    # NOTE: XPU ("::phi::XPUContext", "phi_xpu") is not enabled yet.
    contexts = {
        "CPU": ("::phi::CPUContext", "phi_cpu"),
        "GPU": ("::phi::GPUContext", "phi_gpu"),
    }
    return contexts.get(val, ("", ""))


def gen_layout(val):
    """Return the IR layout name for *val*; only ALL_LAYOUT maps (to 'any').

    Raises:
        Exception: for every other layout string.
    """
    # TODO(wilber): now only process ALL_LAYOUT
    if val != "ALL_LAYOUT":
        raise Exception(f"Unknown layout type {val}")
    return 'any'


def gen_kernel_func(val, ctx_name, dtype_name):
    """Return the C++ address-of expression for a phi kernel function.

    - A plain name such as ``AddKernel`` is instantiated with the dtype and
      context: ``&::phi::AddKernel<dtype_name, ctx_name>``.
    - A name that already carries template arguments (``Foo<Args>``) keeps
      them. ``::phi::`` is prepended to ``Args`` unless they already
      contain it.
    """
    if '<' in val and '>' in val:
        open_idx = val.index('<')
        # Use the last '>' so nested template arguments such as
        # ``Foo<Bar<float>>`` are kept whole (the first '>' cut them short).
        close_idx = val.rindex('>')
        func_name = val[:open_idx]
        template_name = val[open_idx + 1:close_idx]
        if '::phi::' in template_name:
            return "&::phi::" + val
        return "&::phi::" + func_name + "<::phi::" + template_name + ">"
    return "&::phi::" + val + "<" + dtype_name + ", " + ctx_name + ">"


# dtype spelling in the kernel info file -> (IR dtype name, C++ type name).
_DTYPE_MAP = {
    "float": ("float32", "float"),
    "double": ("float64", "double"),
    "float16": ("float16", "paddle::experimental::float16"),
    "bfloat16": ("bf16", "paddle::experimental::bfloat16"),
    "bool": ("bool", "bool"),
    "int8_t": ("int8", "int8_t"),
    "uint8_t": ("uint8", "uint8_t"),
    "int16_t": ("int16", "int16_t"),
    "int": ("int32", "int32_t"),
    "int32_t": ("int32", "int32_t"),
    "int64_t": ("int64", "int64_t"),
    "complex<float>": ("complex64", "paddle::experimental::complex64"),
    "complex64": ("complex64", "paddle::experimental::complex64"),
    "complex<double>": ("complex128", "paddle::experimental::complex128"),
    "complex128": ("complex128", "paddle::experimental::complex128"),
    "pstring": ("pstring", "paddle::experimental::pstring"),
    "ALL_DTYPE": ("all", "all"),
}


def gen_dtype(vals: List[str]):
    """Translate the dtype names of a kernel registration.

    Args:
        vals: dtype spellings taken from a kernel info line.

    Returns:
        (ir_dtypes, origin_dtypes): parallel lists holding, for each
        recognised dtype, its IR name and its C++ type name.

    Raises:
        Exception: for an unrecognised dtype that does not contain "VA_ARGS".
    """
    ir_dtypes, origin_dtypes = [], []
    for val in vals:
        mapped = _DTYPE_MAP.get(val)
        if mapped is None:
            # Presumably residue of a variadic registration macro
            # (__VA_ARGS__) rather than a dtype; it is skipped.
            if "VA_ARGS" in val:
                continue
            raise Exception(f"Unknown data type {val}")
        ir_dtypes.append(mapped[0])
        origin_dtypes.append(mapped[1])
    return ir_dtypes, origin_dtypes


# NOTE: Only CPUContext and GPUContext kernels are registered for now.


def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
    """Return the C++ ``registry->AddKernel`` calls for one kernel info entry.

    item: ['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(variadic types), 'ElementwiseInferMeta']
    attr_data: {'phi_cpu.arg_min.float32.any': ['axisBool', 'keepdimsBool', 'flatten', 'dtype']}

    One AddKernel call is emitted per recognised dtype. If ``attr_data`` has
    a non-None attr list for the kernel's IR name, that list is passed as a
    brace-initialised list of quoted names.

    Returns "" when the backend is unsupported, the inferMeta function is
    "unknown", or the lower-cased kernel name is in the skip list.
    ``item`` is not modified.
    """
    ctx_name, ir_ctx_name = gen_context(item[1])
    if ctx_name == "":
        return ""
    # Kept local: the old code overwrote item[2], mutating the caller's list.
    layout = gen_layout(item[2])
    ir_dtypes, origin_dtypes = gen_dtype(item[4:-1])
    infer_shape_func = "&::phi::" + item[-1]

    if item[-1] == "unknown":
        # TODO(wilber): handle the unknown inferShape func.
        return ""

    # Skipping depends only on the kernel name, so decide once, not per dtype.
    kernel_name = item[0].lower()
    if kernel_name in get_skipped_kernel_list():
        return ""

    indent = " " * 34  # aligns continuation lines under ``decltype(``
    res = ""
    for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
        kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype)
        ir_name = ".".join([ir_ctx_name, kernel_name, ir_dtype, layout.lower()])
        res += (f'\nregistry->AddKernel("{ir_name}",'
                f"\n    &KernelLauncherFunc<decltype({kernel_func}),"
                f"\n{indent}{kernel_func},"
                f"\n{indent}decltype({infer_shape_func}),"
                f"\n{indent}{infer_shape_func}>")
        attrs = attr_data.get(ir_name)
        # ``is not None`` (not truthiness): an empty attr list still emits
        # the ``{}`` argument, as before.
        if attrs is not None:
            attr_names = ", ".join('"' + a + '"' for a in attrs)
            res += ",\n    {" + attr_names + "});\n"
        else:
            res += ");\n"
    return res


def gen_register_info(resources: List[List[str]],
                      attr_data: Dict[str, List[str]]):
    """Return the C++ ``RegisterInferShapeLaunchers`` function definition.

    resources: [['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(variadic types), 'ElementwiseInferMeta'], ...]
    attr_data: {'phi_cpu.arg_min.float32.any': ['axisBool', 'keepdimsBool', 'flatten', 'dtype']}

    CPU kernels are registered unconditionally. GPU kernels are wrapped in
    ``#ifdef INFRT_WITH_GPU``. Entries for any other backend are ignored.
    """

    def _register_backend(backend):
        # Concatenated registration code of every entry for ``backend``.
        codes = []
        for item in resources:
            # The output string is polluted by C++ macros, here the \ is
            # removed. The cleaned copy is what gets registered; the raw item
            # used to be passed on, which made this cleanup a no-op.
            update_item = [v.strip('\\') for v in item]
            if update_item[1] != backend:
                continue
            codes.append(gen_register_code_info(update_item, attr_data))
        return "".join(codes)

    res = "void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry) {"
    res += _register_backend("CPU")
    res += "\n#ifdef INFRT_WITH_GPU"
    res += _register_backend("GPU")
    # Always put #endif on its own line. With no GPU kernels it used to be
    # glued onto the #ifdef line, leaving the #ifdef unterminated.
    res += "\n#endif // INFRT_WITH_GPU"
    res += "\n}"
    return res


def gen_phi_kernel_register_code(resources: List[List[str]],
                                 attr_data: Dict[str, List[str]],
                                 src_file_path: str):
    """Generate the infershaped kernel launcher source and write it to disk.

    The whole text is assembled before the file is opened. An exception
    during generation (e.g. an unknown dtype or layout) therefore leaves an
    existing output file untouched instead of truncated. The ``with`` block
    closes the file even if writing fails.
    """
    namespace_open, namespace_close = gen_namespace()
    content = "".join([
        gen_warn_info(),
        gen_include_headers(),
        namespace_open,
        gen_register_info(resources, attr_data),
        namespace_close,
    ])
    with open(src_file_path, 'w') as source_file:
        source_file.write(content)


if __name__ == "__main__":
    args = parse_args()
    # The path constants are relative to the Paddle root. Rebinding the module
    # globals here makes get_skipped_kernel_list() read the absolute paths too.
    skipped_phi_api_list_file = args.paddle_root_path + skipped_phi_api_list_file
    api_yaml_file = args.paddle_root_path + api_yaml_file
    infer_meta_data = get_api_yaml_info(api_yaml_file)
    kernel_data = get_kernel_info(args.kernel_info_file)
    info_meta_wrap_data = get_infermeta_info(args.infermeta_wrap_file)
    attr_data = get_attr_info(args.attr_info_file)
    out = merge(infer_meta_data, kernel_data, info_meta_wrap_data)
    gen_phi_kernel_register_code(out, attr_data, args.generate_file)