diff --git a/lite/tools/cmake_tools/parse_kernel_registry.py b/lite/tools/cmake_tools/parse_kernel_registry.py index 99804748f3780990194c429b050d364e3fa20b53..a0a123898bec18594ae12bfd1584cdd526cb1a33 100644 --- a/lite/tools/cmake_tools/parse_kernel_registry.py +++ b/lite/tools/cmake_tools/parse_kernel_registry.py @@ -24,35 +24,55 @@ out_lines = [ '', ] +left_pattern = 'REGISTER_LITE_KERNEL(' +right_pattern = ')' + +def find_right_pattern(context, start): + if start >= len(context): return -1 + fake_left_num = 0 + while start < len(context): + if context[start] == right_pattern: + if fake_left_num == 0: + return start + else: + fake_left_num -= 1 + elif context[start] == '(': + fake_left_num += 1 + start += 1 + return -1 + +lines = set() with open(ops_list_path) as f: for line in f: - path = line.strip() - - status = '' - with open(path) as g: - lines = [v for v in g] - for i in range(len(lines)): - line = lines[i].strip() - - if not status: - key = 'REGISTER_LITE_KERNEL' - if line.startswith(key): - forward = i + min(7, len(lines) - i) - remaining = line[len(key) + 1:] + ' '.join( - [v.strip() for v in lines[i + 1:forward]]) - - x = remaining.find('.') - if x > 0: - remaining = remaining[:x] - - fs = [v.strip() for v in remaining.split(',')] - assert (len(fs) >= 4) - op, target, precision, layout, __, alias = fs[:6] - alias = alias.replace(')', '') - - key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( - op, target, precision, layout, alias) - out_lines.append(key) + lines.add(line.strip()) + +for line in lines: + path = line.strip() + + status = '' + with open(path) as g: + context = ''.join([item.strip() for item in g]) + index = 0 + cxt_len = len(context) + while index < cxt_len and index >= 0: + left_index = context.find(left_pattern, index) + if left_index < 0: break + right_index = find_right_pattern(context, left_index+len(left_pattern)) + if right_index < 0: + raise ValueError("Left Pattern and Right Pattern does not match") + tmp = context[left_index+len(left_pattern) : right_index] + index = right_index + 1 + if tmp.startswith('/'): continue + fields = [item.strip() for item in tmp.split(',')] + if len(fields) < 6: + raise ValueError("Invalid REGISTER_LITE_KERNEL format") + + op, target, precision, layout = fields[:4] + alias = fields[-1] + key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( + op, target, precision, layout, alias) + out_lines.append(key) + with open(dest_path, 'w') as f: logging.info("write kernel list to %s" % dest_path) diff --git a/lite/tools/cmake_tools/parse_op_registry.py b/lite/tools/cmake_tools/parse_op_registry.py index 423036f6e84ffed39bb6d12589bbe354fcf8b883..6c936c899d1bd030cc7bf2c35bc8b1247608bfed 100644 --- a/lite/tools/cmake_tools/parse_op_registry.py +++ b/lite/tools/cmake_tools/parse_op_registry.py @@ -25,20 +25,24 @@ out_lines = [ '', ] +lines = set() with open(ops_list_path) as f: for line in f: - path = line.strip() + lines.add(line.strip()) - with open(path) as g: - for line in g: - key = 'REGISTER_LITE_OP' - if line.startswith(key): - end = line.find(',') - op = line[len(key) + 1:end] - if not op: continue - if "_grad" in op: continue - out = "USE_LITE_OP(%s);" % op - out_lines.append(out) +for line in lines: + path = line.strip() + + with open(path) as g: + for line in g: + key = 'REGISTER_LITE_OP' + if line.startswith(key): + end = line.find(',') + op = line[len(key) + 1:end] + if not op: continue + if "_grad" in op: continue + out = "USE_LITE_OP(%s);" % op + out_lines.append(out) with open(dest_path, 'w') as f: logging.info("write op list to %s" % dest_path)