提交 d85fcc1f 编写于 作者: S sangoly 提交者: Yan Chunwei

[API Generate] fix generate api bug test=develop (#1872)

上级 e936f8c1
...@@ -24,35 +24,55 @@ out_lines = [ ...@@ -24,35 +24,55 @@ out_lines = [
'', '',
] ]
left_pattern = 'REGISTER_LITE_KERNEL('
right_pattern = ')'
def find_right_pattern(context, start):
if start >= len(context): return -1
fake_left_num = 0
while start < len(context):
if context[start] == right_pattern:
if fake_left_num == 0:
return start
else:
fake_left_num -= 1
elif context[start] == '(':
fake_left_num += 1
start += 1
return -1
lines = set()
with open(ops_list_path) as f: with open(ops_list_path) as f:
for line in f: for line in f:
path = line.strip() lines.add(line.strip())
status = '' for line in lines:
with open(path) as g: path = line.strip()
lines = [v for v in g]
for i in range(len(lines)): status = ''
line = lines[i].strip() with open(path) as g:
context = ''.join([item.strip() for item in g])
if not status: index = 0
key = 'REGISTER_LITE_KERNEL' cxt_len = len(context)
if line.startswith(key): while index < cxt_len and index >= 0:
forward = i + min(7, len(lines) - i) left_index = context.find(left_pattern, index)
remaining = line[len(key) + 1:] + ' '.join( if left_index < 0: break
[v.strip() for v in lines[i + 1:forward]]) right_index = find_right_pattern(context, left_index+len(left_pattern))
if right_index < 0:
x = remaining.find('.') raise ValueError("Left Pattern and Right Pattern does not match")
if x > 0: tmp = context[left_index+len(left_pattern) : right_index]
remaining = remaining[:x] index = right_index + 1
if tmp.startswith('/'): continue
fs = [v.strip() for v in remaining.split(',')] fields = [item.strip() for item in tmp.split(',')]
assert (len(fs) >= 4) if len(fields) < 6:
op, target, precision, layout, __, alias = fs[:6] raise ValueError("Invalid REGISTER_LITE_KERNEL format")
alias = alias.replace(')', '')
op, target, precision, layout = fields[:4]
key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( alias = fields[-1]
op, target, precision, layout, alias) key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % (
out_lines.append(key) op, target, precision, layout, alias)
out_lines.append(key)
with open(dest_path, 'w') as f: with open(dest_path, 'w') as f:
logging.info("write kernel list to %s" % dest_path) logging.info("write kernel list to %s" % dest_path)
......
...@@ -25,20 +25,24 @@ out_lines = [ ...@@ -25,20 +25,24 @@ out_lines = [
'', '',
] ]
lines = set()
with open(ops_list_path) as f: with open(ops_list_path) as f:
for line in f: for line in f:
path = line.strip() lines.add(line.strip())
with open(path) as g: for line in lines:
for line in g: path = line.strip()
key = 'REGISTER_LITE_OP'
if line.startswith(key): with open(path) as g:
end = line.find(',') for line in g:
op = line[len(key) + 1:end] key = 'REGISTER_LITE_OP'
if not op: continue if line.startswith(key):
if "_grad" in op: continue end = line.find(',')
out = "USE_LITE_OP(%s);" % op op = line[len(key) + 1:end]
out_lines.append(out) if not op: continue
if "_grad" in op: continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
with open(dest_path, 'w') as f: with open(dest_path, 'w') as f:
logging.info("write op list to %s" % dest_path) logging.info("write op list to %s" % dest_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册