未验证 提交 b3afac8a 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Polish @to_static temporary file directory to speed up transformation (#47102)

* [Dy2Stat]Polish @to_static temporary file directory

* [Dy2Stat]Polish @to_static temporary file directory

* refine temp.name

* fix typo

* fix typo
上级 2814d7f6
......@@ -19,7 +19,8 @@ import copy
import collections
from paddle.utils import gast
import inspect
import os
import os, sys
import shutil
import six
import tempfile
import textwrap
......@@ -82,6 +83,7 @@ dygraph_class_to_static_api = {
"PolynomialDecay": "polynomial_decay",
}
DEL_TEMP_DIR = True # A flag to avoid atexit.register more than once
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
......@@ -546,6 +548,22 @@ def create_assign_node(name, node):
return targets, assign_node
def get_temp_dir():
"""
Return @to_static temp directory.
"""
dir_name = "paddle/to_static_tmp"
temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name)
is_windows = sys.platform.startswith('win')
if is_windows:
temp_dir = os.path.normpath(temp_dir)
if not os.path.exists(temp_dir):
os.makedirs(temp_dir)
return temp_dir
def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
"""
Transform modified AST of decorated function into python callable object.
......@@ -553,27 +571,40 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
function, the other inner functions are invisible for the decorated function.
"""
def remove_if_exit(filepath):
if os.path.exists(filepath):
os.remove(filepath)
def remove_if_exit(dir_path):
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
def func_prefix(func):
pre_fix = func.__name__
if hasattr(func, '__self__'):
try:
pre_fix = func.__self__.__class__.__name__ + '_' + func.__name__
except:
pass
return pre_fix
source = ast_to_source_code(ast_root)
source = _inject_import_statements() + source
temp_dir = get_temp_dir()
f = tempfile.NamedTemporaryFile(mode='w',
prefix=func_prefix(dyfunc),
suffix='.py',
delete=False,
dir=temp_dir,
encoding='utf-8')
with f:
module_name = os.path.basename(f.name[:-3])
f.write(source)
if delete_on_exit:
atexit.register(lambda: remove_if_exit(f.name))
atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc"))
global DEL_TEMP_DIR
if delete_on_exit and DEL_TEMP_DIR:
# Clear temporary files in TEMP_DIR while exitting Python process
atexit.register(remove_if_exit, dir_path=temp_dir)
DEL_TEMP_DIR = False
module = SourceFileLoader(module_name, f.name).load_module()
func_name = dyfunc.__name__
module = SourceFileLoader(module_name, f.name).load_module()
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# through 'func_name'. So set the special function name '__i_m_p_l__'.
if hasattr(module, '__i_m_p_l__'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册