Unverified commit 6fc6bb3e, authored by Huihuang Zheng, committed by GitHub

Refactoring Program Cache Related Code (#23118)

1. Rename AutoTracer to ProgramTranslator
2. Rename cached_program to program_cache
3. Remove some functor-style __call__ methods
4. Use the code string itself, not its hash, as the dict key (see the sketch below)
Parent 4db03190
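Point 4 is worth a brief illustration. The sketch below is not part of this commit: `get_or_cache` and `build` are hypothetical stand-ins (for `FunctionCache.get_or_cache_func` and `convert_to_static`). It shows why keying a cache by the code string itself is safer than keying it by `hash(code)`: distinct strings may collide on their hash, but a dict keyed by the string resolves collisions by equality, so entries are never silently shared.

```python
# Hypothetical sketch: cache keyed by the code string itself, as this
# commit does, rather than by hash(code). The dict hashes the key
# internally and falls back to string equality on collision, so two
# different code strings can never end up sharing one cache slot.
cache = {}

def get_or_cache(code_string, build):
    # `build` stands in for convert_to_static(); it runs at most once
    # per distinct code string.
    if code_string not in cache:
        cache[code_string] = build(code_string)
    return cache[code_string]

# The second call is a cache hit, so `build` is invoked only once.
first = get_or_cache("def f(): pass", lambda s: compile(s, "<s>", "exec"))
again = get_or_cache("def f(): pass", lambda s: compile(s, "<s>", "exec"))
assert first is again
```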
@@ -26,12 +26,12 @@ from .loop_transformer import *
 from . import variable_trans_func
 from .variable_trans_func import *
-from . import cache_program
-from .cache_program import *
+from . import program_translator
+from .program_translator import *
 __all__ = []
 __all__ += ast_transformer.__all__
 __all__ += loop_transformer.__all__
 __all__ += static_analysis.__all__
 __all__ += variable_trans_func.__all__
-__all__ += cache_program.__all__
+__all__ += program_translator.__all__
@@ -24,7 +24,7 @@ from paddle.fluid import core, executor
 from paddle.fluid.data import data
 from paddle.fluid.dygraph.dygraph_to_static import convert_to_static
-__all__ = ['AutoTracer']
+__all__ = ['ProgramTranslator']
 class FunctionCache(object):
@@ -33,36 +33,32 @@ class FunctionCache(object):
     """
     def __init__(self):
-        self._cache_funcs = dict()
-        self._func_to_transformer = dict()
+        self._dycode_to_static_func = dict()
+        self._static_func_to_transformer = dict()
-    def __call__(self, func):
-        static_func = self._get_or_cache_func(func)
-        return static_func
-    def _get_or_cache_func(self, func):
-        cache_key = self.hash_key(func)
-        static_func = self._cache_funcs.get(cache_key, None)
+    def get_or_cache_func(self, func):
+        code = self._get_dedent_code_string(func)
+        static_func = self._dycode_to_static_func.get(code, None)
         if static_func is None:
-            static_func, dygraph_to_static = convert_to_static(func)
-            self._cache_funcs[cache_key] = static_func
-            self._func_to_transformer[static_func] = dygraph_to_static
+            static_func, dygraph_to_static_transformer = convert_to_static(func)
+            self._dycode_to_static_func[code] = static_func
+            self._static_func_to_transformer[
+                static_func] = dygraph_to_static_transformer
         return static_func
-    def transformer(self, func):
-        return self._func_to_transformer.get(func, None)
+    def get_transformer(self, func):
+        return self._static_func_to_transformer.get(func, None)
-    def hash_key(self, func):
+    def _get_dedent_code_string(self, func):
         raw_code = inspect.getsource(func)
-        code = textwrap.dedent(raw_code)
-        return hash(code)
+        dedent_code = textwrap.dedent(raw_code)
+        return dedent_code
     def exist(self, func):
-        return self._cache_funcs.get(self.hash_key(func), None) is not None
+        return self._dycode_to_static_func.get(
+            self._get_dedent_code_string(func), None) is not None
 def synchronized(func):
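A side note on the new cache key: `_get_dedent_code_string` in the hunk above derives it from the function's dedented source text. Below is a self-contained, runnable mirror of that derivation; `cache_key_for` and `example` are made-up names, not Paddle API.

```python
import inspect
import textwrap

def cache_key_for(func):
    # Mirrors FunctionCache._get_dedent_code_string in the hunk above:
    # fetch the source, strip the common leading whitespace, and use the
    # resulting string (not its hash) as the dict key.
    raw_code = inspect.getsource(func)
    return textwrap.dedent(raw_code)

def example(x):
    return x + 1

# Prints the dedented source string that would key the cache:
# 'def example(x):\n    return x + 1\n'
print(repr(cache_key_for(example)))
```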
@@ -97,9 +93,10 @@ class ProgramCache(object):
         # sub class in `forward()`.
         self._in_build_process = True
-    def __call__(self, dyfunc, *args, **kwargs):
+    def build_program_and_return_output(self, dyfunc, *args, **kwargs):
         """
-        Executes the main_program with specialized inputs.
+        Executes the main_program with specialized inputs so that the program
+        is built. This method also returns the outputs of the program as fetch_list.
         """
         # Transforms the dygraph function into static functions and caches them.
         static_func = self._transform_or_cache_layers(dyfunc)
@@ -124,7 +121,7 @@ class ProgramCache(object):
         """
         Transforms the dygraph function into a static function.
         """
-        static_func = self._func_cache(dyfunc)
+        static_func = self._func_cache.get_or_cache_func(dyfunc)
         # self._forward_func is the entry function of the Net or Model.
         # It can be called multiple times, but layers from these functions'
         # call stacks will be added into self._program only once.
@@ -181,7 +178,7 @@ class ProgramCache(object):
         """
         Returns name and index of input args from `forward(args)`
         that need to be replaced with `fluid.data`.
         """
-        transformer = self._func_cache.transformer(func)
+        transformer = self._func_cache.get_transformer(func)
         feed_name_to_idx = transformer.get_feed_name_to_idx()
         return feed_name_to_idx
@@ -206,7 +203,7 @@ class ProgramCache(object):
         return self._in_build_process
-class AutoTracer(object):
+class ProgramTranslator(object):
     _instance = None
@@ -214,32 +211,32 @@ class AutoTracer(object):
     def __new__(cls, *args, **kwargs):
         if cls._instance is None:
             cls._instance = object.__new__(cls, *args, **kwargs)
-            cls._instance.__initialized = False
+            cls._instance._initialized = False
         return cls._instance
     @classmethod
     def get_instance(cls):
         if cls._instance is None:
-            raise ValueError("FuncProgram hasn\'t been created!")
+            raise ValueError("ProgramTranslator hasn\'t been created!")
         return cls._instance
     @classmethod
     def reset(cls):
         if cls._instance is not None:
-            cls._instance.__initialized = False
+            cls._instance._initialized = False
             cls._instance.__init__()
     def __init__(self, exe=None, place=None):
         # Make sure that __init__ is called only once.
-        if self.__initialized:
+        if self._initialized:
             return
-        self.__initialized = True
+        self._initialized = True
         self._place = core.CPUPlace() if place is None else place
         if exe is None:
             self._exe = executor.Executor(self._place)
         else:
             self._exe = exe
-        self._cached_program = ProgramCache()
+        self._program_cache = ProgramCache()
         self._optimizer = None
         self._already_minimized = False
         # Once main_program is changed, the startup_program should be run.
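For readers unfamiliar with the pattern in the hunk above: `__new__` hands back one shared instance, and the `_initialized` flag stops `__init__` from re-running on later constructions. Here is a standalone sketch under made-up names (`Translator` is not a Paddle class):

```python
class Translator(object):
    # Hypothetical stand-in for ProgramTranslator's singleton machinery.
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # The shared instance is created exactly once.
            cls._instance = object.__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Guard so repeated Translator() calls do not re-initialize state,
        # mirroring the _initialized check in the diff above.
        if self._initialized:
            return
        self._initialized = True
        self.state = {}

a, b = Translator(), Translator()
assert a is b            # both names refer to the one shared instance
a.state["key"] = 1
assert b.state["key"] == 1
```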
@@ -251,7 +248,7 @@ class AutoTracer(object):
         """
         feed_dict, fetch_list = self._prepare(args)
-        main_program = self._cached_program.program
+        main_program = self._program_cache.program
         outputs = self._exe.run(main_program,
                                 feed=feed_dict,
                                 fetch_list=fetch_list)
@@ -266,7 +263,7 @@ class AutoTracer(object):
         # Updates batch_data for feed_dict
         feed_dict = self._update_batch_data(args)
-        fetch_list = self._cached_program.outputs
+        fetch_list = self._program_cache.outputs
         # Adds optimizer if needed.
         if self._optimizer and not self._already_minimized:
@@ -284,16 +281,16 @@ class AutoTracer(object):
         In some models and unittests, the program will be switched frequently by `program_guard`.
         If so, the cached program and other properties are not available and should be reset.
         """
-        if self._cached_program.program:
-            if self._cached_program.program != framework.default_main_program():
-                AutoTracer.reset()
+        if self._program_cache.program:
+            if self._program_cache.program != framework.default_main_program():
+                ProgramTranslator.reset()
     def _update_batch_data(self, args):
         """
         Updates the cached batch data during training.
         """
-        feed_name_to_idx = self._cached_program.feed_name_to_idx
-        feed_vars = self._cached_program.inputs
+        feed_name_to_idx = self._program_cache.feed_name_to_idx
+        feed_vars = self._program_cache.inputs
         feed_dict = {}
         for feed_var in feed_vars:
             idx = feed_name_to_idx[feed_var.name]
"""
Supports to set or update the optimizer used to minimize loss.
"""
main_program = self._cached_program.program
main_program = self._program_cache.program
all_vars = main_program.block(0).vars
loss_var = all_vars.get(self._loss_name, None)
@@ -333,13 +330,13 @@ class AutoTracer(object):
         # Avoids setting the optimizer repeatedly.
         self._already_minimized = True
-    def get_cached_program(self):
+    def get_program_cache(self):
         """
         Returns the ProgramCache instance.
         """
         self._check_cache_valid()
-        return self._cached_program
+        return self._program_cache
     @property
     def program(self):
-        return self._cached_program.program
+        return self._program_cache.program
@@ -20,7 +20,7 @@ import warnings
 from ..wrapped_decorator import wrap_decorator
 from .base import program_desc_tracing_guard, switch_to_static_graph
-from .dygraph_to_static import AutoTracer, convert_to_static
+from .dygraph_to_static import ProgramTranslator, convert_to_static
 from .layers import Layer
 from paddle.fluid import core
 from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
@@ -68,9 +68,7 @@ dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_)
 def _dygraph_to_static_output_(dygraph_func):
     # Singleton object to cache main_program to avoid inserting ops repeatedly.
-    # TODO: Need a better class name
-    auto_tracer = AutoTracer()
+    program_translator = ProgramTranslator()
     def __impl__(*args, **kwargs):
         if in_dygraph_mode():
@@ -79,12 +77,13 @@ def _dygraph_to_static_output_(dygraph_func):
                 " Please use it in static mode.")
             return dygraph_func(*args, **kwargs)
-        cached_program = auto_tracer.get_cached_program()
-        outputs = cached_program(dygraph_func, *args, **kwargs)
+        program_cache = program_translator.get_program_cache()
+        outputs = program_cache.build_program_and_return_output(dygraph_func,
+                                                                *args, **kwargs)
         # Run the program to fetch output Tensors once it is built successfully.
-        if not cached_program.in_build_process:
-            outputs = auto_tracer.run(*args, **kwargs)
+        if not program_cache.in_build_process:
+            outputs = program_translator.run(*args, **kwargs)
         return outputs
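The flow of `__impl__` above, build the program on the first call and execute the cached program on subsequent calls, can be summarized in a simplified, self-contained sketch. Every name below is illustrative; `build` and `run` are injected stand-ins, not Paddle APIs:

```python
import functools

def static_output(build, run):
    # `build` traces the function once; `run` executes the cached program.
    def decorator(func):
        state = {"built": False}

        @functools.wraps(func)
        def __impl__(*args, **kwargs):
            if not state["built"]:
                state["built"] = True
                return build(func, *args, **kwargs)  # first call: build path
            return run(*args, **kwargs)              # later calls: run path
        return __impl__
    return decorator

@static_output(build=lambda f, *a: f(*a), run=lambda *a: sum(a))
def add(x, y):
    return x + y

print(add(1, 2))  # 3: traced through the build path
print(add(3, 4))  # 7: served by the (stand-in) cached-program run path
```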
@@ -20,7 +20,7 @@ from collections import Counter
 import paddle.fluid as fluid
-from paddle.fluid.dygraph.dygraph_to_static import AutoTracer
+from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
 from paddle.fluid.dygraph.jit import dygraph_to_static_output
 from test_fetch_feed import Pool2D, Linear
@@ -77,8 +77,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
         adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
         # set optimizer
         # TODO: Need a better interface to set the optimizer.
-        auto_tracer = AutoTracer()
-        auto_tracer.set_optimizer(adam, 'avg_loss')
+        program_translator = ProgramTranslator()
+        program_translator.set_optimizer(adam, 'avg_loss')
         for batch_id in range(self.batch_num):
             pred, avg_loss = static_net(self.data)