From b3afac8a3e77f418deda645e955d17b0e93e238b Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 19 Oct 2022 10:56:51 +0800 Subject: [PATCH] [Dy2Stat]Polish @to_static temporary file directory to speed up transformation (#47102) * [Dy2Stat]Polish @to_static temporary file directory * [Dy2Stat]Polish @to_static temporary file directory * refine temp.name * fix typo * fix typo --- .../fluid/dygraph/dygraph_to_static/utils.py | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index dbd5dc800d..849ff338ae 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -19,7 +19,8 @@ import copy import collections from paddle.utils import gast import inspect -import os +import os, sys +import shutil import six import tempfile import textwrap @@ -82,6 +83,7 @@ dygraph_class_to_static_api = { "PolynomialDecay": "polynomial_decay", } +DEL_TEMP_DIR = True # A flag to avoid atexit.register more than once FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple' FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target' @@ -546,6 +548,22 @@ def create_assign_node(name, node): return targets, assign_node +def get_temp_dir(): + """ + Return @to_static temp directory. + """ + dir_name = "paddle/to_static_tmp" + temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name) + is_windows = sys.platform.startswith('win') + if is_windows: + temp_dir = os.path.normpath(temp_dir) + + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + return temp_dir + + def ast_to_func(ast_root, dyfunc, delete_on_exit=True): """ Transform modified AST of decorated function into python callable object. @@ -553,27 +571,40 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): function, the other inner functions are invisible for the decorated function. """ - def remove_if_exit(filepath): - if os.path.exists(filepath): - os.remove(filepath) + def remove_if_exit(dir_path): + if os.path.exists(dir_path): + shutil.rmtree(dir_path) + + def func_prefix(func): + pre_fix = func.__name__ + if hasattr(func, '__self__'): + try: + pre_fix = func.__self__.__class__.__name__ + '_' + func.__name__ + except: + pass + return pre_fix source = ast_to_source_code(ast_root) source = _inject_import_statements() + source - + temp_dir = get_temp_dir() f = tempfile.NamedTemporaryFile(mode='w', + prefix=func_prefix(dyfunc), suffix='.py', delete=False, + dir=temp_dir, encoding='utf-8') with f: module_name = os.path.basename(f.name[:-3]) f.write(source) - if delete_on_exit: - atexit.register(lambda: remove_if_exit(f.name)) - atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc")) + global DEL_TEMP_DIR + if delete_on_exit and DEL_TEMP_DIR: + # Clear temporary files in TEMP_DIR while exitting Python process + atexit.register(remove_if_exit, dir_path=temp_dir) + DEL_TEMP_DIR = False - module = SourceFileLoader(module_name, f.name).load_module() func_name = dyfunc.__name__ + module = SourceFileLoader(module_name, f.name).load_module() # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained # through 'func_name'. So set the special function name '__i_m_p_l__'. if hasattr(module, '__i_m_p_l__'): -- GitLab