From 57ad9b46a93f734633464d21ab259033ba82afe0 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Wed, 7 Dec 2022 09:53:25 +0800 Subject: [PATCH] [Dy2St] replace deprecated `load_module` with `exec_module` (#48679) --- python/paddle/jit/dy2static/utils.py | 10 ++++++---- python/paddle/utils/cpp_extension/extension_utils.py | 5 ++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index a57134411f..438baef376 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -18,6 +18,7 @@ import atexit import copy from paddle.utils import gast import inspect +import importlib.util import os import sys import shutil @@ -32,6 +33,7 @@ from paddle.fluid import core from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layers import assign from functools import reduce +from importlib.machinery import SourceFileLoader import warnings @@ -71,9 +73,6 @@ class BaseNodeVisitor(gast.NodeVisitor): return ret -# imp is deprecated in python3 -from importlib.machinery import SourceFileLoader - dygraph_class_to_static_api = { "CosineDecay": "cosine_decay", "ExponentialDecay": "exponential_decay", @@ -586,7 +585,10 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): DEL_TEMP_DIR = False func_name = dyfunc.__name__ - module = SourceFileLoader(module_name, f.name).load_module() + loader = SourceFileLoader(module_name, f.name) + spec = importlib.util.spec_from_loader(loader.name, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained # through 'func_name'. So set the special function name '__i_m_p_l__'. if hasattr(module, '__i_m_p_l__'): diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 09b5492e54..29a4deeb1c 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -16,6 +16,7 @@ import atexit import collections import glob import hashlib +import importlib.util import json import logging import os @@ -1070,7 +1071,9 @@ def _load_module_from_file(api_file_path, module_name, verbose=False): # load module with RWLock loader = machinery.SourceFileLoader(ext_name, api_file_path) - module = loader.load_module() + spec = importlib.util.spec_from_loader(loader.name, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) return module -- GitLab