未验证 提交 351a7743 编写于 作者: J juncaipeng 提交者: GitHub

Modify parse_op_registry (#2239)

* modify parse_op_registry. When REGISTER_LITE_OP and op_name not in the same row, it also can obtain op_name, test=develop
上级 cbe7098e
......@@ -34,7 +34,7 @@ bool ConvOpLite::CheckShape() const {
CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size());
CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U);
CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size());
// CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size());
// CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.groups);
// CHECK_EQ_OR_FALSE(filter_dims[0] % param_.groups, 0);
......
......@@ -310,6 +310,43 @@ class RegisterLiteKernelParser(SyntaxParser):
break
class RegisterLiteOpParser(SyntaxParser):
KEYWORD = 'REGISTER_LITE_OP'
def __init__(self, str):
super(RegisterLiteOpParser, self).__init__(str)
self.ops = []
def parse(self):
while self.cur_pos < len(self.str):
start = self.str.find(self.KEYWORD, self.cur_pos)
if start != -1:
#print 'str ', start, self.str[start-2: start]
if start != 0 and '/' in self.str[start-2: start]:
'''
skip commented code
'''
self.cur_pos = start + 1
continue
self.cur_pos = start
self.ops.append(self.__parse_register())
else:
break
return self.ops
def __parse_register(self):
self.eat_word()
assert self.token == self.KEYWORD
self.eat_spaces()
self.eat_left_parentheses()
self.eat_spaces()
self.eat_word()
return self.token
if __name__ == '__main__':
with open('/home/chunwei/project2/Paddle-Lite/lite/kernels/arm/activation_compute.cc') as f:
c = f.read()
......
......@@ -15,6 +15,7 @@
import sys
import logging
from ast import RegisterLiteOpParser
ops_list_path = sys.argv[1]
dest_path = sys.argv[2]
......@@ -25,24 +26,19 @@ out_lines = [
'',
]
lines = set()
with open(ops_list_path) as f:
for line in f:
lines.add(line.strip())
paths = set()
for line in open(ops_list_path):
paths.add(line.strip())
for line in lines:
path = line.strip()
with open(path) as g:
for line in g:
key = 'REGISTER_LITE_OP'
if line.startswith(key):
end = line.find(',')
op = line[len(key) + 1:end]
if not op: continue
if "_grad" in op: continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
for path in paths:
str_info = open(path.strip()).read()
op_parser = RegisterLiteOpParser(str_info)
ops = op_parser.parse()
for op in ops:
if "_grad" in op:
continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
with open(dest_path, 'w') as f:
logging.info("write op list to %s" % dest_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册