From 274bd25386a538f465487d2099b7adae52b6b477 Mon Sep 17 00:00:00 2001 From: wuyongkang Date: Mon, 29 Jun 2020 20:39:52 +0800 Subject: [PATCH] Optimize parser --- mindspore/_extends/parse/parser.py | 29 ++++++++++++++++++----- mindspore/ccsrc/pipeline/parse/resolve.cc | 2 +- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 2a1c9e094..30731316e 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -19,6 +19,7 @@ import ast import types import inspect +import hashlib from textwrap import dedent from dataclasses import is_dataclass import asttokens @@ -319,7 +320,6 @@ def get_dataclass_methods(cls): if isinstance(getattr(cls, name), (types.FunctionType,))} return methods - class Parser: """ Parser python code to ast tree. @@ -327,7 +327,10 @@ class Parser: Args: fn(FunctionType/MethodType): Need parse object instance. parse_method(ExtendInfoOfParseObj): Extend information for parse the function. + ast_cache: Dictionary for caching ast tree. """ + ast_cache = {} + def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None: self.fn = fn self.parse_method = parse_method @@ -342,17 +345,31 @@ class Parser: self.function_name = fn.__name__ self.col_offset = 0 + @classmethod + def get_cache(cls, key): + """Get the value of the ast_cache dictionary""" + return cls.ast_cache.get(key) + + @classmethod + def insert_cache(cls, key, value): + """Insert elements to the ast_cache dictionary""" + cls.ast_cache[key] = value + def parse(self): """Parse the function or method.""" logger.debug("fn = %r", self.fn) tree = None if isinstance(self.fn, (types.FunctionType, types.MethodType)): original_src = inspect.getsource(self.fn) - src = dedent(original_src) - self.col_offset = \ - len(original_src.split('\n')[0]) - len(src.split('\n')[0]) - logger.debug("get source = %s", src) - tree = asttokens.ASTTokens(src, parse=True).tree + hexstr = hashlib.sha256(original_src.encode()).hexdigest() + tree = Parser.get_cache(hexstr) + if not tree: + src = dedent(original_src) + self.col_offset = \ + len(original_src.split('\n')[0]) - len(src.split('\n')[0]) + logger.debug("get source = %s", src) + tree = asttokens.ASTTokens(src, parse=True).tree + Parser.insert_cache(hexstr, tree) else: logger.error("Fn type is invalid") return tree diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc index d5e1f828c..87c2f78b4 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/parse/resolve.cc @@ -94,7 +94,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object auto top_graph = Parser::GetTopFuncGraph(); // if the parameter node has been created , return it AnfNodePtr para_node = nullptr; - for (auto param : top_graph->parameters()) { + for (auto const ¶m : top_graph->parameters()) { auto param_node = dyn_cast(param); if (param_node != nullptr && param_node->name() == param_name) { para_node = param; -- GitLab