Unverified commit 6fc6bb3e, authored by Huihuang Zheng, committed by GitHub

Refactoring Program Cache Related Code (#23118)

1. Rename AutoTracer to ProgramTranslator
2. Rename cached_program to program_cache
3. Remove some functor-style __call__ methods
4. Use the source-code string as the dict key, not the hash of the string
Parent 4db03190
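Change 4 is the subtle one. Below is a minimal standalone sketch of the idea, not the Paddle implementation (`get_or_cache` and `_source_to_code` are hypothetical names, and `compile()` stands in for `convert_to_static()`): keying the cache by the dedented source string itself, rather than by `hash(code)`, rules out hash collisions between different functions and leaves a human-readable key when inspecting the cache.

```python
# Hypothetical sketch (not the Paddle implementation) of keying a function
# cache by dedented source text, mirroring change 4 above.
import inspect
import textwrap

_source_to_code = {}

def get_or_cache(func):
    # textwrap.dedent strips the common indentation, so a method and a
    # module-level function with identical bodies produce the same key.
    key = textwrap.dedent(inspect.getsource(func))
    if key not in _source_to_code:
        # Any expensive transformation could go here; compile() stands in
        # for convert_to_static().
        _source_to_code[key] = compile(key, "<cached>", "exec")
    return _source_to_code[key]
```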
@@ -26,12 +26,12 @@ from .loop_transformer import *
 from . import variable_trans_func
 from .variable_trans_func import *
-from . import cache_program
-from .cache_program import *
+from . import program_translator
+from .program_translator import *

 __all__ = []

 __all__ += ast_transformer.__all__
 __all__ += loop_transformer.__all__
 __all__ += static_analysis.__all__
 __all__ += variable_trans_func.__all__
-__all__ += cache_program.__all__
+__all__ += program_translator.__all__
@@ -24,7 +24,7 @@ from paddle.fluid import core, executor
 from paddle.fluid.data import data
 from paddle.fluid.dygraph.dygraph_to_static import convert_to_static

-__all__ = ['AutoTracer']
+__all__ = ['ProgramTranslator']


 class FunctionCache(object):
@@ -33,36 +33,32 @@ class FunctionCache(object):
     """

     def __init__(self):
-        self._cache_funcs = dict()
-        self._func_to_transformer = dict()
+        self._dycode_to_static_func = dict()
+        self._static_func_to_transformer = dict()

-    def __call__(self, func):
-        static_func = self._get_or_cache_func(func)
-        return static_func
-
-    def _get_or_cache_func(self, func):
-        cache_key = self.hash_key(func)
-        static_func = self._cache_funcs.get(cache_key, None)
+    def get_or_cache_func(self, func):
+        code = self._get_dedent_code_string(func)
+        static_func = self._dycode_to_static_func.get(code, None)
         if static_func is None:
-            static_func, dygraph_to_static = convert_to_static(func)
-            self._cache_funcs[cache_key] = static_func
-            self._func_to_transformer[static_func] = dygraph_to_static
+            static_func, dygraph_to_static_transformer = convert_to_static(func)
+            self._dycode_to_static_func[code] = static_func
+            self._static_func_to_transformer[
+                static_func] = dygraph_to_static_transformer
         return static_func

-    def transformer(self, func):
-        return self._func_to_transformer.get(func, None)
+    def get_transformer(self, func):
+        return self._static_func_to_transformer.get(func, None)

-    def hash_key(self, func):
+    def _get_dedent_code_string(self, func):
         raw_code = inspect.getsource(func)
-        code = textwrap.dedent(raw_code)
-
-        return hash(code)
+        dedent_code = textwrap.dedent(raw_code)
+        return dedent_code

     def exist(self, func):
-        return self._cache_funcs.get(self.hash_key(func), None) is not None
+        return self._dycode_to_static_func.get(
+            self._get_dedent_code_string(func), None) is not None


 def synchronized(func):
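A hedged usage sketch of the renamed FunctionCache API above; `forward` is a made-up example, and the calls assume FunctionCache is reachable from the program_translator module (it is not in `__all__`):

```python
# Illustrative only; runs against this version's internals.
cache = FunctionCache()

def forward(x):
    return x * 2

static_fn = cache.get_or_cache_func(forward)    # converts on the first call
assert cache.exist(forward)                     # cached under its dedented source
same_fn = cache.get_or_cache_func(forward)      # second call is a pure dict hit
assert same_fn is static_fn
transformer = cache.get_transformer(static_fn)  # transformer from convert_to_static
```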
@@ -97,9 +93,10 @@ class ProgramCache(object):
         # sub class in `forward()`.
         self._in_build_process = True

-    def __call__(self, dyfunc, *args, **kwargs):
+    def build_program_and_return_output(self, dyfunc, *args, **kwargs):
         """
-        Executes the main_program with specialized inputs.
+        Executes the main_program with specialized inputs so that the program
+        is built. This method also returns the outputs of the program as fetch_list.
         """
         # Transforms the dygraph function into static functions and caches them.
         static_func = self._transform_or_cache_layers(dyfunc)
@@ -124,7 +121,7 @@ class ProgramCache(object):
         """
         Transforms dygraph function into static function.
         """
-        static_func = self._func_cache(dyfunc)
+        static_func = self._func_cache.get_or_cache_func(dyfunc)
         # self._forward_func is the entry function of the Net or Model.
         # It can be called multiple times, but layers from these function
         # calls will be added into self._program only once.
@@ -181,7 +178,7 @@ class ProgramCache(object):
         Returns name and index of input args from `forward(args)`
         that need to be replaced with `fluid.data`.
         """
-        transformer = self._func_cache.transformer(func)
+        transformer = self._func_cache.get_transformer(func)
         feed_name_to_idx = transformer.get_feed_name_to_idx()
         return feed_name_to_idx
@@ -206,7 +203,7 @@ class ProgramCache(object):
         return self._in_build_process


-class AutoTracer(object):
+class ProgramTranslator(object):

     _instance = None
@@ -214,32 +211,32 @@ class AutoTracer(object):
     def __new__(cls, *args, **kwargs):
         if cls._instance is None:
             cls._instance = object.__new__(cls, *args, **kwargs)
-            cls._instance.__initialized = False
+            cls._instance._initialized = False
         return cls._instance

     @classmethod
     def get_instance(cls):
         if cls._instance is None:
-            raise ValueError("FuncProgram hasn't been created!")
+            raise ValueError("ProgramTranslator hasn't been created!")
         return cls._instance

     @classmethod
     def reset(cls):
         if cls._instance is not None:
-            cls._instance.__initialized = False
+            cls._instance._initialized = False
             cls._instance.__init__()

     def __init__(self, exe=None, place=None):
         # Make sure that __init__ is called only once.
-        if self.__initialized:
+        if self._initialized:
             return
-        self.__initialized = True
+        self._initialized = True
         self._place = core.CPUPlace() if place is None else place
         if exe is None:
             self._exe = executor.Executor(self._place)
         else:
             self._exe = exe
-        self._cached_program = ProgramCache()
+        self._program_cache = ProgramCache()
         self._optimizer = None
         self._already_minimized = False
         # Once main_program is changed, the startup_program should be run.
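The singleton machinery is worth isolating. A minimal self-contained sketch (the class name `Singleton` is hypothetical): note that renaming `__initialized` to `_initialized` also sidesteps Python's double-underscore name mangling, under which `self.__x` becomes `self._ClassName__x` and is awkward to reach from subclasses or from outside the class.

```python
# Minimal sketch of the __new__-based singleton pattern used by ProgramTranslator.
class Singleton(object):
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = object.__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every Singleton() call; the flag makes the body
        # execute only once, and a reset() could clear it to force a re-init.
        if self._initialized:
            return
        self._initialized = True

a = Singleton()
b = Singleton()
assert a is b
```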
@@ -251,7 +248,7 @@ class AutoTracer(object):
         """
         feed_dict, fetch_list = self._prepare(args)

-        main_program = self._cached_program.program
+        main_program = self._program_cache.program
         outputs = self._exe.run(main_program,
                                 feed=feed_dict,
                                 fetch_list=fetch_list)
@@ -266,7 +263,7 @@ class AutoTracer(object):
         # Updates batch_data for feed_dict
         feed_dict = self._update_batch_data(args)
-        fetch_list = self._cached_program.outputs
+        fetch_list = self._program_cache.outputs

         # Adds optimizer if needed.
         if self._optimizer and not self._already_minimized:
@@ -284,16 +281,16 @@ class AutoTracer(object):
         In some models and unit tests, the program is switched frequently by `program_guard`.
         If so, the cached program and other properties are not available and should be reset.
         """
-        if self._cached_program.program:
-            if self._cached_program.program != framework.default_main_program():
-                AutoTracer.reset()
+        if self._program_cache.program:
+            if self._program_cache.program != framework.default_main_program():
+                ProgramTranslator.reset()

     def _update_batch_data(self, args):
         """
         Updates cached batch data while training the program.
         """
-        feed_name_to_idx = self._cached_program.feed_name_to_idx
-        feed_vars = self._cached_program.inputs
+        feed_name_to_idx = self._program_cache.feed_name_to_idx
+        feed_vars = self._program_cache.inputs
         feed_dict = {}
         for feed_var in feed_vars:
             idx = feed_name_to_idx[feed_var.name]
@@ -318,7 +315,7 @@ class AutoTracer(object):
         """
         Sets or updates the optimizer used to minimize the loss.
         """
-        main_program = self._cached_program.program
+        main_program = self._program_cache.program
         all_vars = main_program.block(0).vars
         loss_var = all_vars.get(self._loss_name, None)
@@ -333,13 +330,13 @@ class AutoTracer(object):
         # Avoid setting the optimizer repeatedly.
         self._already_minimized = True

-    def get_cached_program(self):
+    def get_program_cache(self):
         """
         Returns the ProgramCache instance.
         """
         self._check_cache_valid()
-        return self._cached_program
+        return self._program_cache

     @property
     def program(self):
-        return self._cached_program.program
+        return self._program_cache.program
@@ -20,7 +20,7 @@ import warnings
 from ..wrapped_decorator import wrap_decorator
 from .base import program_desc_tracing_guard, switch_to_static_graph
-from .dygraph_to_static import AutoTracer, convert_to_static
+from .dygraph_to_static import ProgramTranslator, convert_to_static
 from .layers import Layer
 from paddle.fluid import core
 from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
@@ -68,9 +68,7 @@ dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_)
 def _dygraph_to_static_output_(dygraph_func):
-    # Singleton object to cache main_program to avoid inserting ops repeatedly.
-    # TODO: Need a better class name
-    auto_tracer = AutoTracer()
+    program_translator = ProgramTranslator()

     def __impl__(*args, **kwargs):
         if in_dygraph_mode():
@@ -79,12 +77,13 @@ def _dygraph_to_static_output_(dygraph_func):
                 " Please use it in static mode.")
             return dygraph_func(*args, **kwargs)

-        cached_program = auto_tracer.get_cached_program()
-        outputs = cached_program(dygraph_func, *args, **kwargs)
+        program_cache = program_translator.get_program_cache()
+        outputs = program_cache.build_program_and_return_output(dygraph_func,
+                                                                *args, **kwargs)

         # Run the program to fetch output Tensors once it is built successfully.
-        if not cached_program.in_build_process:
-            outputs = auto_tracer.run(*args, **kwargs)
+        if not program_cache.in_build_process:
+            outputs = program_translator.run(*args, **kwargs)
         return outputs
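For context, a hedged end-to-end sketch of the decorator path above; the network body and input data are made up, and the exact behavior depends on this Paddle version.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_output

@dygraph_to_static_output
def net(x):
    # The first call builds the static program and caches it; subsequent
    # calls only execute it through ProgramTranslator.run().
    return fluid.layers.relu(x)

x = np.random.random((4, 8)).astype('float32')
out = net(x)
```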
@@ -20,7 +20,7 @@ from collections import Counter
 import paddle.fluid as fluid
-from paddle.fluid.dygraph.dygraph_to_static import AutoTracer
+from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
 from paddle.fluid.dygraph.jit import dygraph_to_static_output

 from test_fetch_feed import Pool2D, Linear
@@ -77,8 +77,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
         adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
         # Set the optimizer.
         # TODO: Need a better interface to set the optimizer.
-        auto_tracer = AutoTracer()
-        auto_tracer.set_optimizer(adam, 'avg_loss')
+        program_translator = ProgramTranslator()
+        program_translator.set_optimizer(adam, 'avg_loss')
         for batch_id in range(self.batch_num):
             pred, avg_loss = static_net(self.data)
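A condensed, hedged restatement of the optimizer flow this test exercises; 'avg_loss' must match the name of the loss variable inside the built program.

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator

adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
# The singleton returns the same translator the decorated net uses.
ProgramTranslator().set_optimizer(adam, 'avg_loss')
# minimize() is applied lazily on the first run after the program is built,
# and _already_minimized prevents inserting the optimizer ops twice.
```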