提交 274bd253 编写于 作者: W wuyongkang

Optimize parser

上级 9ef744db
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import ast import ast
import types import types
import inspect import inspect
import hashlib
from textwrap import dedent from textwrap import dedent
from dataclasses import is_dataclass from dataclasses import is_dataclass
import asttokens import asttokens
...@@ -319,7 +320,6 @@ def get_dataclass_methods(cls): ...@@ -319,7 +320,6 @@ def get_dataclass_methods(cls):
if isinstance(getattr(cls, name), (types.FunctionType,))} if isinstance(getattr(cls, name), (types.FunctionType,))}
return methods return methods
class Parser: class Parser:
""" """
Parser python code to ast tree. Parser python code to ast tree.
...@@ -327,7 +327,10 @@ class Parser: ...@@ -327,7 +327,10 @@ class Parser:
Args: Args:
fn(FunctionType/MethodType): Need parse object instance. fn(FunctionType/MethodType): Need parse object instance.
parse_method(ExtendInfoOfParseObj): Extend information for parse the function. parse_method(ExtendInfoOfParseObj): Extend information for parse the function.
ast_cache: Dictionary for caching ast tree.
""" """
ast_cache = {}
def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None: def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
self.fn = fn self.fn = fn
self.parse_method = parse_method self.parse_method = parse_method
...@@ -342,17 +345,31 @@ class Parser: ...@@ -342,17 +345,31 @@ class Parser:
self.function_name = fn.__name__ self.function_name = fn.__name__
self.col_offset = 0 self.col_offset = 0
@classmethod
def get_cache(cls, key):
"""Get the value of the ast_cache dictionary"""
return cls.ast_cache.get(key)
@classmethod
def insert_cache(cls, key, value):
"""Insert elements to the ast_cache dictionary"""
cls.ast_cache[key] = value
def parse(self): def parse(self):
"""Parse the function or method.""" """Parse the function or method."""
logger.debug("fn = %r", self.fn) logger.debug("fn = %r", self.fn)
tree = None tree = None
if isinstance(self.fn, (types.FunctionType, types.MethodType)): if isinstance(self.fn, (types.FunctionType, types.MethodType)):
original_src = inspect.getsource(self.fn) original_src = inspect.getsource(self.fn)
src = dedent(original_src) hexstr = hashlib.sha256(original_src.encode()).hexdigest()
self.col_offset = \ tree = Parser.get_cache(hexstr)
len(original_src.split('\n')[0]) - len(src.split('\n')[0]) if not tree:
logger.debug("get source = %s", src) src = dedent(original_src)
tree = asttokens.ASTTokens(src, parse=True).tree self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("get source = %s", src)
tree = asttokens.ASTTokens(src, parse=True).tree
Parser.insert_cache(hexstr, tree)
else: else:
logger.error("Fn type is invalid") logger.error("Fn type is invalid")
return tree return tree
......
...@@ -94,7 +94,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object ...@@ -94,7 +94,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
auto top_graph = Parser::GetTopFuncGraph(); auto top_graph = Parser::GetTopFuncGraph();
// if the parameter node has been created , return it // if the parameter node has been created , return it
AnfNodePtr para_node = nullptr; AnfNodePtr para_node = nullptr;
for (auto param : top_graph->parameters()) { for (auto const &param : top_graph->parameters()) {
auto param_node = dyn_cast<Parameter>(param); auto param_node = dyn_cast<Parameter>(param);
if (param_node != nullptr && param_node->name() == param_name) { if (param_node != nullptr && param_node->name() == param_name) {
para_node = param; para_node = param;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册