提交 413db281 编写于 作者: S sangoly 提交者: Yan Chunwei

[API Generate] fix generate api bug test=develop (#1872)

上级 3bbeffa7
......@@ -24,35 +24,55 @@ out_lines = [
'',
]
left_pattern = 'REGISTER_LITE_KERNEL('
right_pattern = ')'
def find_right_pattern(context, start):
if start >= len(context): return -1
fake_left_num = 0
while start < len(context):
if context[start] == right_pattern:
if fake_left_num == 0:
return start
else:
fake_left_num -= 1
elif context[start] == '(':
fake_left_num += 1
start += 1
return -1
lines = set()
with open(ops_list_path) as f:
for line in f:
path = line.strip()
status = ''
with open(path) as g:
lines = [v for v in g]
for i in range(len(lines)):
line = lines[i].strip()
if not status:
key = 'REGISTER_LITE_KERNEL'
if line.startswith(key):
forward = i + min(7, len(lines) - i)
remaining = line[len(key) + 1:] + ' '.join(
[v.strip() for v in lines[i + 1:forward]])
x = remaining.find('.')
if x > 0:
remaining = remaining[:x]
fs = [v.strip() for v in remaining.split(',')]
assert (len(fs) >= 4)
op, target, precision, layout, __, alias = fs[:6]
alias = alias.replace(')', '')
key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % (
op, target, precision, layout, alias)
out_lines.append(key)
lines.add(line.strip())
for line in lines:
path = line.strip()
status = ''
with open(path) as g:
context = ''.join([item.strip() for item in g])
index = 0
cxt_len = len(context)
while index < cxt_len and index >= 0:
left_index = context.find(left_pattern, index)
if left_index < 0: break
right_index = find_right_pattern(context, left_index+len(left_pattern))
if right_index < 0:
raise ValueError("Left Pattern and Right Pattern does not match")
tmp = context[left_index+len(left_pattern) : right_index]
index = right_index + 1
if tmp.startswith('/'): continue
fields = [item.strip() for item in tmp.split(',')]
if len(fields) < 6:
raise ValueError("Invalid REGISTER_LITE_KERNEL format")
op, target, precision, layout = fields[:4]
alias = fields[-1]
key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % (
op, target, precision, layout, alias)
out_lines.append(key)
with open(dest_path, 'w') as f:
logging.info("write kernel list to %s" % dest_path)
......
......@@ -25,20 +25,24 @@ out_lines = [
'',
]
lines = set()
with open(ops_list_path) as f:
for line in f:
path = line.strip()
lines.add(line.strip())
with open(path) as g:
for line in g:
key = 'REGISTER_LITE_OP'
if line.startswith(key):
end = line.find(',')
op = line[len(key) + 1:end]
if not op: continue
if "_grad" in op: continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
for line in lines:
path = line.strip()
with open(path) as g:
for line in g:
key = 'REGISTER_LITE_OP'
if line.startswith(key):
end = line.find(',')
op = line[len(key) + 1:end]
if not op: continue
if "_grad" in op: continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
with open(dest_path, 'w') as f:
logging.info("write op list to %s" % dest_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册