diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 22c77dc854ed4ae5e137579968e9b5a341b1cdaf..b57d74f6470079a267968b6a4f384fc7c4d8c941 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -494,6 +494,8 @@ class NameVisitor(gast.NodeVisitor): self.generic_visit(node) def visit_Name(self, node): + blacklist = {'True', 'False', 'None'} + if node.id in blacklist: return if not self._is_call_func_name_node(node): if isinstance(node.ctx, self._candidate_ctxs): self.name_ids[node.id].append(node.ctx) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 62f2e50c0d46bc1d48a12dc68dc9be11d8a238bd..84a5adaf7c01a668b1857b01ce3587b5e9216bba 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -16,11 +16,9 @@ from __future__ import print_function import gast import inspect import numpy -import six import textwrap import threading import warnings -from collections import defaultdict from paddle.fluid import framework from paddle.fluid import core, executor @@ -75,7 +73,7 @@ _FUNCTION_CACHE = FunctionCache() def convert_function_with_cache(dygraph_func): """ - Transform function of dygraph into static function using the cache mechanism. + Transforms function of dygraph into static function using the cache mechanism. """ with _CACHE_LOCK: static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func) @@ -106,9 +104,9 @@ class ProgramCache(object): self._main_program = framework.default_main_program() self._startup_program = framework.default_startup_program() self._func_cache = FunctionCache() + self._feed_name_to_idx = {} # Stores the entry function of Net or Model. self._forward_func = None - self._feed_name_to_idx = {} self._is_repeated = False # Indicates whether the function call is still building program. # Because user can call recursively when `Net` has sub class in @@ -117,10 +115,10 @@ class ProgramCache(object): def build_program_and_return_output(self, dyfunc, *args, **kwargs): """ - Executes the main_program with specialized inputs so that the program - is built. This method also return outputs of program as fetch_list + Builds the main_program with specialized inputs and returns outputs + of program as fetch_list. """ - # Transfroms dygraph function into static functions and caches them. + # Transforms dygraph function into static function and caches it. static_func = self._transform_or_cache_layers(dyfunc) # 1. Adds `fluid.data` layers for input if needed @@ -144,15 +142,23 @@ class ProgramCache(object): Transforms dygraph function into static function. """ static_func = self._func_cache.get_or_cache_func(dyfunc) - # self._forward_func is entry function of Net or Model. - # It can be called for multiple times, but layers from these functions - # call stack will be added into self._main_program only once. - # After that, cached program will be always returned by default. - if static_func == self._forward_func: - self._is_repeated = True if self._forward_func is None: self._forward_func = static_func + else: + # self._forward_func is entry function of Net or Model. + # It can be called for multiple times, but layers from these functions + # call stack will be added into self._main_program only once. + # After that, cached program will be always returned by default. + if static_func == self._forward_func: + self._is_repeated = True + # If a independent function is received after the build process + # has finished, feed layers should be reset. + # TODO(Aurelius84): Switch main_program without specifying program_guard. + elif not self._in_build_process: + self._inputs = [] + self._is_repeated = False + self._forward_func = static_func return static_func @@ -180,8 +186,7 @@ class ProgramCache(object): Adds `fluid.data` if the input `numpy.ndarray` is converted into `Variable` by `to_variable()`, it makes program to be executed dynamically. """ - if not self._feed_name_to_idx: - self._feed_name_to_idx = self._get_name_to_idx(self._forward_func) + self._feed_name_to_idx = self._get_name_to_idx(self._forward_func) with framework.program_guard(self._main_program, self._startup_program): for feed_name, idx in self.feed_name_to_idx.items(): batch_data = args[idx] @@ -267,12 +272,12 @@ class ProgramTranslator(object): self._optimizer_info = None self._optimizer = None self._loss_name = None - # Once main_program is changed, should run startup_program. - self._need_startup = True + # Once startup_program is changed, should run startup_program. + self._prev_startup = None def get_output(self, dygraph_func, *args, **kwargs): """ - Return the output tensors for dygraph function and its arguments + Returns the output tensors for dygraph function and its arguments """ if in_dygraph_mode(): warnings.warn( @@ -292,7 +297,7 @@ class ProgramTranslator(object): def get_func(self, dygraph_func): """ - Return the translated static function from dygraph function + Returns the translated static function from dygraph function """ if in_dygraph_mode(): warnings.warn( @@ -305,7 +310,7 @@ class ProgramTranslator(object): def get_program(self, dygraph_func, *args, **kwargs): """ - Return the translated static program and input/output variables from + Returns the translated static program and input/output variables from dygraph function. """ if in_dygraph_mode(): @@ -321,9 +326,9 @@ class ProgramTranslator(object): def get_code(self, dygraph_func): """ - Return the translated static function code from dygraph code + Returns the translated static function code from dygraph code """ - # Get AST from dygraph function + # Gets AST from dygraph function raw_code = inspect.getsource(dygraph_func) code = textwrap.dedent(raw_code) root = gast.parse(code) @@ -338,7 +343,7 @@ class ProgramTranslator(object): def run(self, *args, **kwargs): """ - Execute main_program and returns output Tensors. + Executes main_program and returns output Tensors. """ feed_dict, fetch_list = self._prepare(args) @@ -351,7 +356,7 @@ class ProgramTranslator(object): def set_optimizer(self, optimizer, index_of_loss=0): """ - Support to set or update the optimizer used to minimize loss. + Supports to set or update the optimizer used to minimize loss. """ check_type(index_of_loss, "index_of_loss", int, "ProgramTranslator.set_optimizer") @@ -364,7 +369,7 @@ class ProgramTranslator(object): def save_inference_model(self, dirname, feed=None, fetch=None): """ - Save current model as the inference model. + Saves current model as the inference model. """ program_cache = self.get_program_cache() if feed is None: @@ -383,27 +388,38 @@ class ProgramTranslator(object): def _prepare(self, args): """ - Prepare with feed_dict, fetch_list, optimizer and initialize vars + Prepares with feed_dict, fetch_list, optimizer and initialize vars by running startup_program. """ - # Update batch_data for feed_dict + # Updates batch_data for feed_dict feed_dict = self._update_batch_data(args) fetch_list = self._program_cache.outputs - # Add optimizer if needed. + # Adds optimizer if needed. if self._optimizer_info and self._optimizer is None: self._add_optimizer() - if self._need_startup: + if self._need_startup(): self._exe.run(self.startup_program) - self._need_startup = False + self._prev_startup = self.startup_program return feed_dict, fetch_list + def _need_startup(self): + """ + Determines whether needy to run startup_program. + """ + if self.startup_program != self._prev_startup: + check_type(self.startup_program, "startup_program", + framework.Program, "_need_startup") + return len(self.startup_program.global_block().ops) > 0 + + return False + def _check_cache_valid(self): """ - Check whether the current program is consistent with `default_main_program`. + Checks whether the current program is consistent with `default_main_program`. In some models and unittest, program will be switched frequently by `program_guard`. If does, the cached program and other properties are not available and should be reset. """ @@ -414,7 +430,7 @@ class ProgramTranslator(object): def _update_batch_data(self, args): """ - Update cached batch data while training program. + Updates cached batch data while training program. """ feed_name_to_idx = self._program_cache.feed_name_to_idx feed_vars = self._program_cache.inputs @@ -427,7 +443,7 @@ class ProgramTranslator(object): def _add_optimizer(self): """ - Support to set or update the optimizer used to minimize loss. + Supports to set or update the optimizer used to minimize loss. """ optimizer, index_of_loss = self._optimizer_info @@ -451,7 +467,7 @@ class ProgramTranslator(object): raise ValueError( "Can't find {} in main_program, please confirm whether the input loss is correct." .format(loss_var.name)) - # Add optimizer to minimize loss + # Adds optimizer to minimize loss with framework.program_guard(main_program, startup_program): optimizer.minimize(loss_var) @@ -460,7 +476,7 @@ class ProgramTranslator(object): def get_program_cache(self): """ - Return the ProgramCache instance. + Returns the ProgramCache instance. """ self._check_cache_valid() return self._program_cache diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py index 0e8b7d787a0f059f476d014353e0151d112b0e46..dbea3965db4de7165ef9cc6586ac5c950acc6367 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py @@ -20,6 +20,7 @@ from collections import Counter import paddle.fluid as fluid +from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache @@ -138,5 +139,37 @@ class TestConvertWithCache(unittest.TestCase): self.assertTrue(id(static_func), id(cached_func)) +@dygraph_to_static_output +def sum_even_util_limit(max_len, limit): + ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32')) + for i in range(max_len): + if i % 2 > 0: + continue + elif i > limit: + break + + ret_sum += i + return ret_sum + + +@dygraph_to_static_output +def sum_under_while(limit): + i = fluid.dygraph.to_variable(np.zeros((1)).astype('int32')) + ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32')) + while i <= limit: + ret_sum += i + i += 1 + return ret_sum + + +class TestToOutputWithCache(unittest.TestCase): + def test_output(self): + ret = sum_even_util_limit(80, 10) + self.assertEqual(ret[0].numpy(), 30) + + ret = sum_under_while(100) + self.assertEqual(ret[0].numpy(), 5050) + + if __name__ == '__main__': unittest.main()