提交 274bd253 编写于 作者: W wuyongkang

Optimize parser

上级 9ef744db
......@@ -19,6 +19,7 @@
import ast
import types
import inspect
import hashlib
from textwrap import dedent
from dataclasses import is_dataclass
import asttokens
......@@ -319,7 +320,6 @@ def get_dataclass_methods(cls):
if isinstance(getattr(cls, name), (types.FunctionType,))}
return methods
class Parser:
"""
Parser python code to ast tree.
......@@ -327,7 +327,10 @@ class Parser:
Args:
fn(FunctionType/MethodType): Need parse object instance.
parse_method(ExtendInfoOfParseObj): Extend information for parse the function.
ast_cache: Dictionary for caching ast tree.
"""
ast_cache = {}
def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
self.fn = fn
self.parse_method = parse_method
......@@ -342,17 +345,31 @@ class Parser:
self.function_name = fn.__name__
self.col_offset = 0
@classmethod
def get_cache(cls, key):
"""Get the value of the ast_cache dictionary"""
return cls.ast_cache.get(key)
@classmethod
def insert_cache(cls, key, value):
"""Insert elements to the ast_cache dictionary"""
cls.ast_cache[key] = value
def parse(self):
"""Parse the function or method."""
logger.debug("fn = %r", self.fn)
tree = None
if isinstance(self.fn, (types.FunctionType, types.MethodType)):
original_src = inspect.getsource(self.fn)
src = dedent(original_src)
self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("get source = %s", src)
tree = asttokens.ASTTokens(src, parse=True).tree
hexstr = hashlib.sha256(original_src.encode()).hexdigest()
tree = Parser.get_cache(hexstr)
if not tree:
src = dedent(original_src)
self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("get source = %s", src)
tree = asttokens.ASTTokens(src, parse=True).tree
Parser.insert_cache(hexstr, tree)
else:
logger.error("Fn type is invalid")
return tree
......
......@@ -94,7 +94,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
auto top_graph = Parser::GetTopFuncGraph();
// if the parameter node has been created , return it
AnfNodePtr para_node = nullptr;
for (auto param : top_graph->parameters()) {
for (auto const &param : top_graph->parameters()) {
auto param_node = dyn_cast<Parameter>(param);
if (param_node != nullptr && param_node->name() == param_name) {
para_node = param;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册