提交 413db281 编写于 作者: S sangoly 提交者: Yan Chunwei

[API Generate] fix generate api bug test=develop (#1872)

上级 3bbeffa7
...@@ -24,36 +24,56 @@ out_lines = [ ...@@ -24,36 +24,56 @@ out_lines = [
'', '',
] ]
left_pattern = 'REGISTER_LITE_KERNEL('
right_pattern = ')'
def find_right_pattern(context, start):
if start >= len(context): return -1
fake_left_num = 0
while start < len(context):
if context[start] == right_pattern:
if fake_left_num == 0:
return start
else:
fake_left_num -= 1
elif context[start] == '(':
fake_left_num += 1
start += 1
return -1
lines = set()
with open(ops_list_path) as f: with open(ops_list_path) as f:
for line in f: for line in f:
lines.add(line.strip())
for line in lines:
path = line.strip() path = line.strip()
status = '' status = ''
with open(path) as g: with open(path) as g:
lines = [v for v in g] context = ''.join([item.strip() for item in g])
for i in range(len(lines)): index = 0
line = lines[i].strip() cxt_len = len(context)
while index < cxt_len and index >= 0:
if not status: left_index = context.find(left_pattern, index)
key = 'REGISTER_LITE_KERNEL' if left_index < 0: break
if line.startswith(key): right_index = find_right_pattern(context, left_index+len(left_pattern))
forward = i + min(7, len(lines) - i) if right_index < 0:
remaining = line[len(key) + 1:] + ' '.join( raise ValueError("Left Pattern and Right Pattern does not match")
[v.strip() for v in lines[i + 1:forward]]) tmp = context[left_index+len(left_pattern) : right_index]
index = right_index + 1
x = remaining.find('.') if tmp.startswith('/'): continue
if x > 0: fields = [item.strip() for item in tmp.split(',')]
remaining = remaining[:x] if len(fields) < 6:
raise ValueError("Invalid REGISTER_LITE_KERNEL format")
fs = [v.strip() for v in remaining.split(',')]
assert (len(fs) >= 4)
op, target, precision, layout, __, alias = fs[:6]
alias = alias.replace(')', '')
op, target, precision, layout = fields[:4]
alias = fields[-1]
key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % (
op, target, precision, layout, alias) op, target, precision, layout, alias)
out_lines.append(key) out_lines.append(key)
with open(dest_path, 'w') as f: with open(dest_path, 'w') as f:
logging.info("write kernel list to %s" % dest_path) logging.info("write kernel list to %s" % dest_path)
f.write('\n'.join(out_lines)) f.write('\n'.join(out_lines))
...@@ -25,8 +25,12 @@ out_lines = [ ...@@ -25,8 +25,12 @@ out_lines = [
'', '',
] ]
lines = set()
with open(ops_list_path) as f: with open(ops_list_path) as f:
for line in f: for line in f:
lines.add(line.strip())
for line in lines:
path = line.strip() path = line.strip()
with open(path) as g: with open(path) as g:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册