parse_kernel_registry.py 2.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import logging

ops_list_path = sys.argv[1]
dest_path = sys.argv[2]

out_lines = [
    '#pragma once',
    '#include "paddle_lite_factory_helper.h"',
    '',
]

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
left_pattern = 'REGISTER_LITE_KERNEL('
right_pattern = ')'

def find_right_pattern(context, start):
   if start >= len(context): return -1
   fake_left_num = 0
   while start < len(context):
       if context[start] == right_pattern:
           if fake_left_num == 0:
               return start
           else:
               fake_left_num -= 1
       elif context[start] == '(':
           fake_left_num += 1
       start += 1
   return -1

lines = set()
45 46
with open(ops_list_path) as f:
    for line in f:
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
        lines.add(line.strip())
    
for line in lines:
    path = line.strip()

    status = ''
    with open(path) as g:
        context = ''.join([item.strip() for item in g])
        index = 0
        cxt_len = len(context)
        while index < cxt_len and index >= 0:
            left_index = context.find(left_pattern, index)
            if left_index < 0: break
            right_index = find_right_pattern(context, left_index+len(left_pattern))
            if right_index < 0:
                raise ValueError("Left Pattern and Right Pattern does not match")
            tmp = context[left_index+len(left_pattern) : right_index]
            index = right_index + 1
            if tmp.startswith('/'): continue
            fields = [item.strip() for item in tmp.split(',')]
            if len(fields) < 6:
                raise ValueError("Invalid REGISTER_LITE_KERNEL format")

            op, target, precision, layout = fields[:4] 
            alias = fields[-1]
            key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % (
                op, target, precision, layout, alias)
            out_lines.append(key)

76 77 78 79

with open(dest_path, 'w') as f:
    logging.info("write kernel list to %s" % dest_path)
    f.write('\n'.join(out_lines))