未验证 提交 57ad9b46 编写于 作者: N Nyakku Shigure 提交者: GitHub

[Dy2St] replace deprecated `load_module` with `exec_module` (#48679)

上级 c838c1ed
...@@ -18,6 +18,7 @@ import atexit ...@@ -18,6 +18,7 @@ import atexit
import copy import copy
from paddle.utils import gast from paddle.utils import gast
import inspect import inspect
import importlib.util
import os import os
import sys import sys
import shutil import shutil
...@@ -32,6 +33,7 @@ from paddle.fluid import core ...@@ -32,6 +33,7 @@ from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import assign from paddle.fluid.layers import assign
from functools import reduce from functools import reduce
from importlib.machinery import SourceFileLoader
import warnings import warnings
...@@ -71,9 +73,6 @@ class BaseNodeVisitor(gast.NodeVisitor): ...@@ -71,9 +73,6 @@ class BaseNodeVisitor(gast.NodeVisitor):
return ret return ret
# imp is deprecated in python3
from importlib.machinery import SourceFileLoader
dygraph_class_to_static_api = { dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay", "CosineDecay": "cosine_decay",
"ExponentialDecay": "exponential_decay", "ExponentialDecay": "exponential_decay",
...@@ -586,7 +585,10 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -586,7 +585,10 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
DEL_TEMP_DIR = False DEL_TEMP_DIR = False
func_name = dyfunc.__name__ func_name = dyfunc.__name__
module = SourceFileLoader(module_name, f.name).load_module() loader = SourceFileLoader(module_name, f.name)
spec = importlib.util.spec_from_loader(loader.name, loader)
module = importlib.util.module_from_spec(spec)
loader.exec_module(module)
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# through 'func_name'. So set the special function name '__i_m_p_l__'. # through 'func_name'. So set the special function name '__i_m_p_l__'.
if hasattr(module, '__i_m_p_l__'): if hasattr(module, '__i_m_p_l__'):
......
...@@ -16,6 +16,7 @@ import atexit ...@@ -16,6 +16,7 @@ import atexit
import collections import collections
import glob import glob
import hashlib import hashlib
import importlib.util
import json import json
import logging import logging
import os import os
...@@ -1070,7 +1071,9 @@ def _load_module_from_file(api_file_path, module_name, verbose=False): ...@@ -1070,7 +1071,9 @@ def _load_module_from_file(api_file_path, module_name, verbose=False):
# load module with RWLock # load module with RWLock
loader = machinery.SourceFileLoader(ext_name, api_file_path) loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module() spec = importlib.util.spec_from_loader(loader.name, loader)
module = importlib.util.module_from_spec(spec)
loader.exec_module(module)
return module return module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册