diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index b96cc350c385551b008cb5571afbaac7a2d847cd..dc5a39138b8a41c67c70b2cba5f45c07fabed5fe 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -477,6 +477,31 @@ class AstEditVisitor(ast.NodeVisitor): return standard_api_call_name, match_case + @staticmethod + def _get_call_parameters_str(call_node): + """Get parameters string for a call node.""" + if not isinstance(call_node, ast.Call): + raise NodeTypeNotSupport('It is not ast.Call node type.') + parameters_str = '' + call_str = pasta.dump(call_node) + call_name = pasta.dump(call_node.func) + last_parameter_str = '' + + if call_node.args: + last_parameter_str = pasta.dump(call_node.args[-1]) + if call_node.keywords: + last_parameter_str = pasta.dump(call_node.keywords[-1]) + if last_parameter_str: + left_parenthesis_pos = call_str.find(call_name) + len(call_name) + # call is like abc.call(a, b,), last parameter is b, + # but parameters string must have last ',' character after the last parameter b. + last_parameter_pos = call_str.rfind(last_parameter_str) + len(last_parameter_str) + right_parenthesis_pos = call_str.find(')', last_parameter_pos) + + # parameters start pos must skip '(' character for calling. + parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos] + return parameters_str + def mapping_api(self, call_node, check_context=True): """ Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. @@ -498,7 +523,7 @@ class AstEditVisitor(ast.NodeVisitor): return code # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" - args_str = code[len(api_call_name):].strip() + args_str = '(' + self._get_call_parameters_str(call_node) + ')' try: api_name, _ = self._infer_api_name(call_node.func, check_context) diff --git a/mindinsight/mindconverter/code_analysis.py b/mindinsight/mindconverter/code_analysis.py index adbbbce6924b3a354f06e097847e27d4dc7b6309..8f1eca84648495528fe5ac33fa50f26b9c536208 100644 --- a/mindinsight/mindconverter/code_analysis.py +++ b/mindinsight/mindconverter/code_analysis.py @@ -193,13 +193,12 @@ class CodeAnalyzer(ast.NodeVisitor): for node_references in root_scope.external_references.values(): for node_ref in node_references: if node_ref.name_ref: - # (from)import alias, node_ref.name_ref.id is alias name - if node_ref.name_ref.definition.asname == node_ref.name_ref.id: - external_name_ref[node_ref.name_ref.id] = node_ref - # import without alias, node_ref.name_ref.definition.asname is None. - # e.g., import a.b.c, reference maybe is a, a.b or a.b.c in the root_scope.external_references. - # The reference a.b.c is really wanted. - elif node_ref.name_ref.definition.name == node_ref.name_ref.id: + # case1: (from)import alias, node_ref.name_ref.id is node_ref.name_ref.definition.asname. + # case2: import without alias, node_ref.name_ref.definition.asname is None. + # e.g., import a.b.c, the reference definition id maybe is a, a.b or a.b.c. + # The reference id a.b.c is really wanted. + if node_ref.name_ref.id in [node_ref.name_ref.definition.asname, + node_ref.name_ref.definition.name]: external_name_ref[node_ref.name_ref.id] = node_ref else: pass