未验证 提交 d265b16a 编写于 作者: A Aurelius84 提交者: GitHub

Support transform two independent functions in dygraph_to_static_output (#23652)

* Spport transform two independed function in dygraph_to_static_output test=develop

* fix unittest error test=develop
上级 b4be5ef5
...@@ -494,6 +494,8 @@ class NameVisitor(gast.NodeVisitor): ...@@ -494,6 +494,8 @@ class NameVisitor(gast.NodeVisitor):
self.generic_visit(node) self.generic_visit(node)
def visit_Name(self, node): def visit_Name(self, node):
blacklist = {'True', 'False', 'None'}
if node.id in blacklist: return
if not self._is_call_func_name_node(node): if not self._is_call_func_name_node(node):
if isinstance(node.ctx, self._candidate_ctxs): if isinstance(node.ctx, self._candidate_ctxs):
self.name_ids[node.id].append(node.ctx) self.name_ids[node.id].append(node.ctx)
......
...@@ -16,11 +16,9 @@ from __future__ import print_function ...@@ -16,11 +16,9 @@ from __future__ import print_function
import gast import gast
import inspect import inspect
import numpy import numpy
import six
import textwrap import textwrap
import threading import threading
import warnings import warnings
from collections import defaultdict
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid import core, executor from paddle.fluid import core, executor
...@@ -75,7 +73,7 @@ _FUNCTION_CACHE = FunctionCache() ...@@ -75,7 +73,7 @@ _FUNCTION_CACHE = FunctionCache()
def convert_function_with_cache(dygraph_func): def convert_function_with_cache(dygraph_func):
""" """
Transform function of dygraph into static function using the cache mechanism. Transforms function of dygraph into static function using the cache mechanism.
""" """
with _CACHE_LOCK: with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func) static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func)
...@@ -106,9 +104,9 @@ class ProgramCache(object): ...@@ -106,9 +104,9 @@ class ProgramCache(object):
self._main_program = framework.default_main_program() self._main_program = framework.default_main_program()
self._startup_program = framework.default_startup_program() self._startup_program = framework.default_startup_program()
self._func_cache = FunctionCache() self._func_cache = FunctionCache()
self._feed_name_to_idx = {}
# Stores the entry function of Net or Model. # Stores the entry function of Net or Model.
self._forward_func = None self._forward_func = None
self._feed_name_to_idx = {}
self._is_repeated = False self._is_repeated = False
# Indicates whether the function call is still building program. # Indicates whether the function call is still building program.
# Because user can call recursively when `Net` has sub class in # Because user can call recursively when `Net` has sub class in
...@@ -117,10 +115,10 @@ class ProgramCache(object): ...@@ -117,10 +115,10 @@ class ProgramCache(object):
def build_program_and_return_output(self, dyfunc, *args, **kwargs): def build_program_and_return_output(self, dyfunc, *args, **kwargs):
""" """
Executes the main_program with specialized inputs so that the program Builds the main_program with specialized inputs and returns outputs
is built. This method also return outputs of program as fetch_list of program as fetch_list.
""" """
# Transfroms dygraph function into static functions and caches them. # Transforms dygraph function into static function and caches it.
static_func = self._transform_or_cache_layers(dyfunc) static_func = self._transform_or_cache_layers(dyfunc)
# 1. Adds `fluid.data` layers for input if needed # 1. Adds `fluid.data` layers for input if needed
...@@ -144,14 +142,22 @@ class ProgramCache(object): ...@@ -144,14 +142,22 @@ class ProgramCache(object):
Transforms dygraph function into static function. Transforms dygraph function into static function.
""" """
static_func = self._func_cache.get_or_cache_func(dyfunc) static_func = self._func_cache.get_or_cache_func(dyfunc)
if self._forward_func is None:
self._forward_func = static_func
else:
# self._forward_func is entry function of Net or Model. # self._forward_func is entry function of Net or Model.
# It can be called for multiple times, but layers from these functions # It can be called for multiple times, but layers from these functions
# call stack will be added into self._main_program only once. # call stack will be added into self._main_program only once.
# After that, cached program will be always returned by default. # After that, cached program will be always returned by default.
if static_func == self._forward_func: if static_func == self._forward_func:
self._is_repeated = True self._is_repeated = True
# If a independent function is received after the build process
if self._forward_func is None: # has finished, feed layers should be reset.
# TODO(Aurelius84): Switch main_program without specifying program_guard.
elif not self._in_build_process:
self._inputs = []
self._is_repeated = False
self._forward_func = static_func self._forward_func = static_func
return static_func return static_func
...@@ -180,7 +186,6 @@ class ProgramCache(object): ...@@ -180,7 +186,6 @@ class ProgramCache(object):
Adds `fluid.data` if the input `numpy.ndarray` is converted into `Variable` Adds `fluid.data` if the input `numpy.ndarray` is converted into `Variable`
by `to_variable()`, it makes program to be executed dynamically. by `to_variable()`, it makes program to be executed dynamically.
""" """
if not self._feed_name_to_idx:
self._feed_name_to_idx = self._get_name_to_idx(self._forward_func) self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
with framework.program_guard(self._main_program, self._startup_program): with framework.program_guard(self._main_program, self._startup_program):
for feed_name, idx in self.feed_name_to_idx.items(): for feed_name, idx in self.feed_name_to_idx.items():
...@@ -267,12 +272,12 @@ class ProgramTranslator(object): ...@@ -267,12 +272,12 @@ class ProgramTranslator(object):
self._optimizer_info = None self._optimizer_info = None
self._optimizer = None self._optimizer = None
self._loss_name = None self._loss_name = None
# Once main_program is changed, should run startup_program. # Once startup_program is changed, should run startup_program.
self._need_startup = True self._prev_startup = None
def get_output(self, dygraph_func, *args, **kwargs): def get_output(self, dygraph_func, *args, **kwargs):
""" """
Return the output tensors for dygraph function and its arguments Returns the output tensors for dygraph function and its arguments
""" """
if in_dygraph_mode(): if in_dygraph_mode():
warnings.warn( warnings.warn(
...@@ -292,7 +297,7 @@ class ProgramTranslator(object): ...@@ -292,7 +297,7 @@ class ProgramTranslator(object):
def get_func(self, dygraph_func): def get_func(self, dygraph_func):
""" """
Return the translated static function from dygraph function Returns the translated static function from dygraph function
""" """
if in_dygraph_mode(): if in_dygraph_mode():
warnings.warn( warnings.warn(
...@@ -305,7 +310,7 @@ class ProgramTranslator(object): ...@@ -305,7 +310,7 @@ class ProgramTranslator(object):
def get_program(self, dygraph_func, *args, **kwargs): def get_program(self, dygraph_func, *args, **kwargs):
""" """
Return the translated static program and input/output variables from Returns the translated static program and input/output variables from
dygraph function. dygraph function.
""" """
if in_dygraph_mode(): if in_dygraph_mode():
...@@ -321,9 +326,9 @@ class ProgramTranslator(object): ...@@ -321,9 +326,9 @@ class ProgramTranslator(object):
def get_code(self, dygraph_func): def get_code(self, dygraph_func):
""" """
Return the translated static function code from dygraph code Returns the translated static function code from dygraph code
""" """
# Get AST from dygraph function # Gets AST from dygraph function
raw_code = inspect.getsource(dygraph_func) raw_code = inspect.getsource(dygraph_func)
code = textwrap.dedent(raw_code) code = textwrap.dedent(raw_code)
root = gast.parse(code) root = gast.parse(code)
...@@ -338,7 +343,7 @@ class ProgramTranslator(object): ...@@ -338,7 +343,7 @@ class ProgramTranslator(object):
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
""" """
Execute main_program and returns output Tensors. Executes main_program and returns output Tensors.
""" """
feed_dict, fetch_list = self._prepare(args) feed_dict, fetch_list = self._prepare(args)
...@@ -351,7 +356,7 @@ class ProgramTranslator(object): ...@@ -351,7 +356,7 @@ class ProgramTranslator(object):
def set_optimizer(self, optimizer, index_of_loss=0): def set_optimizer(self, optimizer, index_of_loss=0):
""" """
Support to set or update the optimizer used to minimize loss. Supports to set or update the optimizer used to minimize loss.
""" """
check_type(index_of_loss, "index_of_loss", int, check_type(index_of_loss, "index_of_loss", int,
"ProgramTranslator.set_optimizer") "ProgramTranslator.set_optimizer")
...@@ -364,7 +369,7 @@ class ProgramTranslator(object): ...@@ -364,7 +369,7 @@ class ProgramTranslator(object):
def save_inference_model(self, dirname, feed=None, fetch=None): def save_inference_model(self, dirname, feed=None, fetch=None):
""" """
Save current model as the inference model. Saves current model as the inference model.
""" """
program_cache = self.get_program_cache() program_cache = self.get_program_cache()
if feed is None: if feed is None:
...@@ -383,27 +388,38 @@ class ProgramTranslator(object): ...@@ -383,27 +388,38 @@ class ProgramTranslator(object):
def _prepare(self, args): def _prepare(self, args):
""" """
Prepare with feed_dict, fetch_list, optimizer and initialize vars Prepares with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program. by running startup_program.
""" """
# Update batch_data for feed_dict # Updates batch_data for feed_dict
feed_dict = self._update_batch_data(args) feed_dict = self._update_batch_data(args)
fetch_list = self._program_cache.outputs fetch_list = self._program_cache.outputs
# Add optimizer if needed. # Adds optimizer if needed.
if self._optimizer_info and self._optimizer is None: if self._optimizer_info and self._optimizer is None:
self._add_optimizer() self._add_optimizer()
if self._need_startup: if self._need_startup():
self._exe.run(self.startup_program) self._exe.run(self.startup_program)
self._need_startup = False self._prev_startup = self.startup_program
return feed_dict, fetch_list return feed_dict, fetch_list
def _need_startup(self):
"""
Determines whether needy to run startup_program.
"""
if self.startup_program != self._prev_startup:
check_type(self.startup_program, "startup_program",
framework.Program, "_need_startup")
return len(self.startup_program.global_block().ops) > 0
return False
def _check_cache_valid(self): def _check_cache_valid(self):
""" """
Check whether the current program is consistent with `default_main_program`. Checks whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`. In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset. If does, the cached program and other properties are not available and should be reset.
""" """
...@@ -414,7 +430,7 @@ class ProgramTranslator(object): ...@@ -414,7 +430,7 @@ class ProgramTranslator(object):
def _update_batch_data(self, args): def _update_batch_data(self, args):
""" """
Update cached batch data while training program. Updates cached batch data while training program.
""" """
feed_name_to_idx = self._program_cache.feed_name_to_idx feed_name_to_idx = self._program_cache.feed_name_to_idx
feed_vars = self._program_cache.inputs feed_vars = self._program_cache.inputs
...@@ -427,7 +443,7 @@ class ProgramTranslator(object): ...@@ -427,7 +443,7 @@ class ProgramTranslator(object):
def _add_optimizer(self): def _add_optimizer(self):
""" """
Support to set or update the optimizer used to minimize loss. Supports to set or update the optimizer used to minimize loss.
""" """
optimizer, index_of_loss = self._optimizer_info optimizer, index_of_loss = self._optimizer_info
...@@ -451,7 +467,7 @@ class ProgramTranslator(object): ...@@ -451,7 +467,7 @@ class ProgramTranslator(object):
raise ValueError( raise ValueError(
"Can't find {} in main_program, please confirm whether the input loss is correct." "Can't find {} in main_program, please confirm whether the input loss is correct."
.format(loss_var.name)) .format(loss_var.name))
# Add optimizer to minimize loss # Adds optimizer to minimize loss
with framework.program_guard(main_program, startup_program): with framework.program_guard(main_program, startup_program):
optimizer.minimize(loss_var) optimizer.minimize(loss_var)
...@@ -460,7 +476,7 @@ class ProgramTranslator(object): ...@@ -460,7 +476,7 @@ class ProgramTranslator(object):
def get_program_cache(self): def get_program_cache(self):
""" """
Return the ProgramCache instance. Returns the ProgramCache instance.
""" """
self._check_cache_valid() self._check_cache_valid()
return self._program_cache return self._program_cache
......
...@@ -20,6 +20,7 @@ from collections import Counter ...@@ -20,6 +20,7 @@ from collections import Counter
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache
...@@ -138,5 +139,37 @@ class TestConvertWithCache(unittest.TestCase): ...@@ -138,5 +139,37 @@ class TestConvertWithCache(unittest.TestCase):
self.assertTrue(id(static_func), id(cached_func)) self.assertTrue(id(static_func), id(cached_func))
@dygraph_to_static_output
def sum_even_util_limit(max_len, limit):
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
for i in range(max_len):
if i % 2 > 0:
continue
elif i > limit:
break
ret_sum += i
return ret_sum
@dygraph_to_static_output
def sum_under_while(limit):
i = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
while i <= limit:
ret_sum += i
i += 1
return ret_sum
class TestToOutputWithCache(unittest.TestCase):
def test_output(self):
ret = sum_even_util_limit(80, 10)
self.assertEqual(ret[0].numpy(), 30)
ret = sum_under_while(100)
self.assertEqual(ret[0].numpy(), 5050)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册