未验证 提交 351a7743 编写于 作者: J juncaipeng 提交者: GitHub

Modify parse_op_registry (#2239)

* modify parse_op_registry. When REGISTER_LITE_OP and op_name not in the same row, it also can obtain op_name, test=develop
上级 cbe7098e
...@@ -34,7 +34,7 @@ bool ConvOpLite::CheckShape() const { ...@@ -34,7 +34,7 @@ bool ConvOpLite::CheckShape() const {
CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size()); CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size());
CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U); CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U);
CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size()); // CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size());
// CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.groups); // CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.groups);
// CHECK_EQ_OR_FALSE(filter_dims[0] % param_.groups, 0); // CHECK_EQ_OR_FALSE(filter_dims[0] % param_.groups, 0);
......
...@@ -310,6 +310,43 @@ class RegisterLiteKernelParser(SyntaxParser): ...@@ -310,6 +310,43 @@ class RegisterLiteKernelParser(SyntaxParser):
break break
class RegisterLiteOpParser(SyntaxParser):
KEYWORD = 'REGISTER_LITE_OP'
def __init__(self, str):
super(RegisterLiteOpParser, self).__init__(str)
self.ops = []
def parse(self):
while self.cur_pos < len(self.str):
start = self.str.find(self.KEYWORD, self.cur_pos)
if start != -1:
#print 'str ', start, self.str[start-2: start]
if start != 0 and '/' in self.str[start-2: start]:
'''
skip commented code
'''
self.cur_pos = start + 1
continue
self.cur_pos = start
self.ops.append(self.__parse_register())
else:
break
return self.ops
def __parse_register(self):
self.eat_word()
assert self.token == self.KEYWORD
self.eat_spaces()
self.eat_left_parentheses()
self.eat_spaces()
self.eat_word()
return self.token
if __name__ == '__main__': if __name__ == '__main__':
with open('/home/chunwei/project2/Paddle-Lite/lite/kernels/arm/activation_compute.cc') as f: with open('/home/chunwei/project2/Paddle-Lite/lite/kernels/arm/activation_compute.cc') as f:
c = f.read() c = f.read()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import sys import sys
import logging import logging
from ast import RegisterLiteOpParser
ops_list_path = sys.argv[1] ops_list_path = sys.argv[1]
dest_path = sys.argv[2] dest_path = sys.argv[2]
...@@ -25,24 +26,19 @@ out_lines = [ ...@@ -25,24 +26,19 @@ out_lines = [
'', '',
] ]
lines = set() paths = set()
with open(ops_list_path) as f: for line in open(ops_list_path):
for line in f: paths.add(line.strip())
lines.add(line.strip())
for line in lines: for path in paths:
path = line.strip() str_info = open(path.strip()).read()
op_parser = RegisterLiteOpParser(str_info)
with open(path) as g: ops = op_parser.parse()
for line in g: for op in ops:
key = 'REGISTER_LITE_OP' if "_grad" in op:
if line.startswith(key): continue
end = line.find(',') out = "USE_LITE_OP(%s);" % op
op = line[len(key) + 1:end] out_lines.append(out)
if not op: continue
if "_grad" in op: continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
with open(dest_path, 'w') as f: with open(dest_path, 'w') as f:
logging.info("write op list to %s" % dest_path) logging.info("write op list to %s" % dest_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册