未验证 提交 5a9befea 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Polish @to_static temporary file directory to speed up transformation (#47102) (#47144)

Polish @to_static temporary file directory to speed up transformation
上级 247ef477
...@@ -21,7 +21,8 @@ import copy ...@@ -21,7 +21,8 @@ import copy
import collections import collections
from paddle.utils import gast from paddle.utils import gast
import inspect import inspect
import os import os, sys
import shutil
import six import six
import tempfile import tempfile
import textwrap import textwrap
...@@ -84,6 +85,7 @@ dygraph_class_to_static_api = { ...@@ -84,6 +85,7 @@ dygraph_class_to_static_api = {
"PolynomialDecay": "polynomial_decay", "PolynomialDecay": "polynomial_decay",
} }
DEL_TEMP_DIR = True # A flag to avoid atexit.register more than once
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple' FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target' FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
...@@ -548,6 +550,22 @@ def create_assign_node(name, node): ...@@ -548,6 +550,22 @@ def create_assign_node(name, node):
return targets, assign_node return targets, assign_node
def get_temp_dir():
"""
Return @to_static temp directory.
"""
dir_name = "paddle/to_static_tmp"
temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name)
is_windows = sys.platform.startswith('win')
if is_windows:
temp_dir = os.path.normpath(temp_dir)
if not os.path.exists(temp_dir):
os.makedirs(temp_dir)
return temp_dir
def ast_to_func(ast_root, dyfunc, delete_on_exit=True): def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
""" """
Transform modified AST of decorated function into python callable object. Transform modified AST of decorated function into python callable object.
...@@ -555,27 +573,40 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -555,27 +573,40 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
function, the other inner functions are invisible for the decorated function. function, the other inner functions are invisible for the decorated function.
""" """
def remove_if_exit(filepath): def remove_if_exit(dir_path):
if os.path.exists(filepath): if os.path.exists(dir_path):
os.remove(filepath) shutil.rmtree(dir_path)
def func_prefix(func):
pre_fix = func.__name__
if hasattr(func, '__self__'):
try:
pre_fix = func.__self__.__class__.__name__ + '_' + func.__name__
except:
pass
return pre_fix
source = ast_to_source_code(ast_root) source = ast_to_source_code(ast_root)
source = _inject_import_statements() + source source = _inject_import_statements() + source
temp_dir = get_temp_dir()
f = tempfile.NamedTemporaryFile(mode='w', f = tempfile.NamedTemporaryFile(mode='w',
prefix=func_prefix(dyfunc),
suffix='.py', suffix='.py',
delete=False, delete=False,
dir=temp_dir,
encoding='utf-8') encoding='utf-8')
with f: with f:
module_name = os.path.basename(f.name[:-3]) module_name = os.path.basename(f.name[:-3])
f.write(source) f.write(source)
if delete_on_exit: global DEL_TEMP_DIR
atexit.register(lambda: remove_if_exit(f.name)) if delete_on_exit and DEL_TEMP_DIR:
atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc")) # Clear temporary files in TEMP_DIR while exitting Python process
atexit.register(remove_if_exit, dir_path=temp_dir)
DEL_TEMP_DIR = False
module = SourceFileLoader(module_name, f.name).load_module()
func_name = dyfunc.__name__ func_name = dyfunc.__name__
module = SourceFileLoader(module_name, f.name).load_module()
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# through 'func_name'. So set the special function name '__i_m_p_l__'. # through 'func_name'. So set the special function name '__i_m_p_l__'.
if hasattr(module, '__i_m_p_l__'): if hasattr(module, '__i_m_p_l__'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册