未验证 提交 d265b16a 编写于 作者: A Aurelius84 提交者: GitHub

Support transform two independent functions in dygraph_to_static_output (#23652)

* Spport transform two independed function in dygraph_to_static_output test=develop

* fix unittest error test=develop
上级 b4be5ef5
......@@ -494,6 +494,8 @@ class NameVisitor(gast.NodeVisitor):
self.generic_visit(node)
def visit_Name(self, node):
blacklist = {'True', 'False', 'None'}
if node.id in blacklist: return
if not self._is_call_func_name_node(node):
if isinstance(node.ctx, self._candidate_ctxs):
self.name_ids[node.id].append(node.ctx)
......
......@@ -16,11 +16,9 @@ from __future__ import print_function
import gast
import inspect
import numpy
import six
import textwrap
import threading
import warnings
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core, executor
......@@ -75,7 +73,7 @@ _FUNCTION_CACHE = FunctionCache()
def convert_function_with_cache(dygraph_func):
"""
Transform function of dygraph into static function using the cache mechanism.
Transforms function of dygraph into static function using the cache mechanism.
"""
with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func)
......@@ -106,9 +104,9 @@ class ProgramCache(object):
self._main_program = framework.default_main_program()
self._startup_program = framework.default_startup_program()
self._func_cache = FunctionCache()
self._feed_name_to_idx = {}
# Stores the entry function of Net or Model.
self._forward_func = None
self._feed_name_to_idx = {}
self._is_repeated = False
# Indicates whether the function call is still building program.
# Because user can call recursively when `Net` has sub class in
......@@ -117,10 +115,10 @@ class ProgramCache(object):
def build_program_and_return_output(self, dyfunc, *args, **kwargs):
"""
Executes the main_program with specialized inputs so that the program
is built. This method also return outputs of program as fetch_list
Builds the main_program with specialized inputs and returns outputs
of program as fetch_list.
"""
# Transfroms dygraph function into static functions and caches them.
# Transforms dygraph function into static function and caches it.
static_func = self._transform_or_cache_layers(dyfunc)
# 1. Adds `fluid.data` layers for input if needed
......@@ -144,14 +142,22 @@ class ProgramCache(object):
Transforms dygraph function into static function.
"""
static_func = self._func_cache.get_or_cache_func(dyfunc)
if self._forward_func is None:
self._forward_func = static_func
else:
# self._forward_func is entry function of Net or Model.
# It can be called for multiple times, but layers from these functions
# call stack will be added into self._main_program only once.
# After that, cached program will be always returned by default.
if static_func == self._forward_func:
self._is_repeated = True
if self._forward_func is None:
# If a independent function is received after the build process
# has finished, feed layers should be reset.
# TODO(Aurelius84): Switch main_program without specifying program_guard.
elif not self._in_build_process:
self._inputs = []
self._is_repeated = False
self._forward_func = static_func
return static_func
......@@ -180,7 +186,6 @@ class ProgramCache(object):
Adds `fluid.data` if the input `numpy.ndarray` is converted into `Variable`
by `to_variable()`, it makes program to be executed dynamically.
"""
if not self._feed_name_to_idx:
self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
with framework.program_guard(self._main_program, self._startup_program):
for feed_name, idx in self.feed_name_to_idx.items():
......@@ -267,12 +272,12 @@ class ProgramTranslator(object):
self._optimizer_info = None
self._optimizer = None
self._loss_name = None
# Once main_program is changed, should run startup_program.
self._need_startup = True
# Once startup_program is changed, should run startup_program.
self._prev_startup = None
def get_output(self, dygraph_func, *args, **kwargs):
"""
Return the output tensors for dygraph function and its arguments
Returns the output tensors for dygraph function and its arguments
"""
if in_dygraph_mode():
warnings.warn(
......@@ -292,7 +297,7 @@ class ProgramTranslator(object):
def get_func(self, dygraph_func):
"""
Return the translated static function from dygraph function
Returns the translated static function from dygraph function
"""
if in_dygraph_mode():
warnings.warn(
......@@ -305,7 +310,7 @@ class ProgramTranslator(object):
def get_program(self, dygraph_func, *args, **kwargs):
"""
Return the translated static program and input/output variables from
Returns the translated static program and input/output variables from
dygraph function.
"""
if in_dygraph_mode():
......@@ -321,9 +326,9 @@ class ProgramTranslator(object):
def get_code(self, dygraph_func):
"""
Return the translated static function code from dygraph code
Returns the translated static function code from dygraph code
"""
# Get AST from dygraph function
# Gets AST from dygraph function
raw_code = inspect.getsource(dygraph_func)
code = textwrap.dedent(raw_code)
root = gast.parse(code)
......@@ -338,7 +343,7 @@ class ProgramTranslator(object):
def run(self, *args, **kwargs):
"""
Execute main_program and returns output Tensors.
Executes main_program and returns output Tensors.
"""
feed_dict, fetch_list = self._prepare(args)
......@@ -351,7 +356,7 @@ class ProgramTranslator(object):
def set_optimizer(self, optimizer, index_of_loss=0):
"""
Support to set or update the optimizer used to minimize loss.
Supports to set or update the optimizer used to minimize loss.
"""
check_type(index_of_loss, "index_of_loss", int,
"ProgramTranslator.set_optimizer")
......@@ -364,7 +369,7 @@ class ProgramTranslator(object):
def save_inference_model(self, dirname, feed=None, fetch=None):
"""
Save current model as the inference model.
Saves current model as the inference model.
"""
program_cache = self.get_program_cache()
if feed is None:
......@@ -383,27 +388,38 @@ class ProgramTranslator(object):
def _prepare(self, args):
"""
Prepare with feed_dict, fetch_list, optimizer and initialize vars
Prepares with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program.
"""
# Update batch_data for feed_dict
# Updates batch_data for feed_dict
feed_dict = self._update_batch_data(args)
fetch_list = self._program_cache.outputs
# Add optimizer if needed.
# Adds optimizer if needed.
if self._optimizer_info and self._optimizer is None:
self._add_optimizer()
if self._need_startup:
if self._need_startup():
self._exe.run(self.startup_program)
self._need_startup = False
self._prev_startup = self.startup_program
return feed_dict, fetch_list
def _need_startup(self):
"""
Determines whether needy to run startup_program.
"""
if self.startup_program != self._prev_startup:
check_type(self.startup_program, "startup_program",
framework.Program, "_need_startup")
return len(self.startup_program.global_block().ops) > 0
return False
def _check_cache_valid(self):
"""
Check whether the current program is consistent with `default_main_program`.
Checks whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset.
"""
......@@ -414,7 +430,7 @@ class ProgramTranslator(object):
def _update_batch_data(self, args):
"""
Update cached batch data while training program.
Updates cached batch data while training program.
"""
feed_name_to_idx = self._program_cache.feed_name_to_idx
feed_vars = self._program_cache.inputs
......@@ -427,7 +443,7 @@ class ProgramTranslator(object):
def _add_optimizer(self):
"""
Support to set or update the optimizer used to minimize loss.
Supports to set or update the optimizer used to minimize loss.
"""
optimizer, index_of_loss = self._optimizer_info
......@@ -451,7 +467,7 @@ class ProgramTranslator(object):
raise ValueError(
"Can't find {} in main_program, please confirm whether the input loss is correct."
.format(loss_var.name))
# Add optimizer to minimize loss
# Adds optimizer to minimize loss
with framework.program_guard(main_program, startup_program):
optimizer.minimize(loss_var)
......@@ -460,7 +476,7 @@ class ProgramTranslator(object):
def get_program_cache(self):
"""
Return the ProgramCache instance.
Returns the ProgramCache instance.
"""
self._check_cache_valid()
return self._program_cache
......
......@@ -20,6 +20,7 @@ from collections import Counter
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache
......@@ -138,5 +139,37 @@ class TestConvertWithCache(unittest.TestCase):
self.assertTrue(id(static_func), id(cached_func))
@dygraph_to_static_output
def sum_even_util_limit(max_len, limit):
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
for i in range(max_len):
if i % 2 > 0:
continue
elif i > limit:
break
ret_sum += i
return ret_sum
@dygraph_to_static_output
def sum_under_while(limit):
i = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
while i <= limit:
ret_sum += i
i += 1
return ret_sum
class TestToOutputWithCache(unittest.TestCase):
def test_output(self):
ret = sum_even_util_limit(80, 10)
self.assertEqual(ret[0].numpy(), 30)
ret = sum_under_while(100)
self.assertEqual(ret[0].numpy(), 5050)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册