From d37cd740339f76d25d2aa26fb5e2a891b53ec7f9 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Thu, 9 Apr 2020 11:00:03 +0800
Subject: [PATCH] Polish set_optimizer Interface (#23588)

---
 .../dygraph_to_static/program_translator.py   | 80 +++++++++++--------
 .../dygraph_to_static/test_cache_program.py   | 17 +++-
 .../dygraph_to_static/test_fetch_feed.py      |  2 +-
 .../test_save_inference_model.py              |  4 +-
 4 files changed, 66 insertions(+), 37 deletions(-)

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index 792fa3e0ff..9682fceb08 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -20,6 +20,7 @@ import six
 import textwrap
 import threading
 import warnings
+from collections import defaultdict
 
 from paddle.fluid import framework
 from paddle.fluid import core, executor
@@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
 from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
 from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.data_feeder import check_type
 
 __all__ = ['ProgramTranslator', 'convert_function_with_cache']
 
@@ -261,19 +263,20 @@ class ProgramTranslator(object):
         else:
             self._exe = exe
         self._program_cache = ProgramCache()
+        self._optimizer_info = None
         self._optimizer = None
-        self._already_minimized = False
+        self._loss_name = None
         # Once main_program is changed, should run startup_program.
         self._need_startup = True
 
     def get_output(self, dygraph_func, *args, **kwargs):
         """
-        Returns the output tensors for dygraph function and its arguments
+        Return the output tensors for the dygraph function and its arguments.
         """
         if in_dygraph_mode():
             warnings.warn(
                 "The ProgramTranslator.get_output doesn't work in dygraph "
-                "mode. We will just return dygraph output. Use the it in "
+                "mode. We will just return dygraph output. Use it in "
                 "static mode if you would like to translate to static graph.")
             return dygraph_func(*args, **kwargs)
 
@@ -286,12 +289,12 @@
 
     def get_func(self, dygraph_func):
         """
-        Returns the translated static function from dygraph function
+        Return the translated static function from the dygraph function.
         """
         if in_dygraph_mode():
             warnings.warn(
                 "The ProgramTranslator.get_func doesn't work in dygraph "
-                "mode. We will just return dygraph function. Use the it in "
+                "mode. We will just return dygraph function. Use it in "
                 "static mode if you would like to translate to static graph.")
             return dygraph_func
         static_func = convert_function_with_cache(dygraph_func)
@@ -299,7 +302,7 @@
 
     def get_program(self, dygraph_func, *args, **kwargs):
         """
-        Returns the translated static program and input/output variables from
+        Return the translated static program and input/output variables from
         dygraph function.
         """
         if in_dygraph_mode():
@@ -315,7 +318,7 @@
 
     def get_code(self, dygraph_func):
         """
-        Returns the translated static function code from dygraph code
+        Return the translated static function code from the dygraph function.
         """
         # Get AST from dygraph function
         raw_code = inspect.getsource(dygraph_func)
@@ -332,7 +335,7 @@
 
     def run(self, *args, **kwargs):
         """
-        Executes main_program and returns output Tensors.
+        Execute main_program and return output Tensors.
         """
 
         feed_dict, fetch_list = self._prepare(args)
@@ -343,18 +346,18 @@ class ProgramTranslator(object):
 
         return outputs
 
-    def set_optimizer(self, optimizer, loss_name):
+    def set_optimizer(self, optimizer, index_of_loss=0):
         """
-        Supports to set or update the optimizer used to minimize loss.
+        Set or update the optimizer used to minimize loss.
         """
+        check_type(index_of_loss, "index_of_loss", int,
+                   "ProgramTranslator.set_optimizer")
         self._check_cache_valid()
-        self._optimizer = optimizer
-
-        if not isinstance(loss_name, six.string_types):
+        if self._optimizer and self._loss_name:
             raise ValueError(
-                "Type of input loss_name should type(str), but received {}.".
-                format(type(loss_name)))
-        self._loss_name = loss_name
+                "{} for {} has already been set before. Please make sure not to call `set_optimizer` in a loop.".
+                format(self._optimizer, self._loss_name))
+        self._optimizer_info = (optimizer, index_of_loss)
 
     def save_inference_model(self, dirname, feed=None, fetch=None):
         """
@@ -377,16 +380,16 @@
 
     def _prepare(self, args):
        """
-        Prepares with feed_dict, fetch_list, optimizer and initialize vars
+        Prepare feed_dict, fetch_list and the optimizer, and initialize vars
         by running startup_program.
         """
-        # Updates batch_data for feed_dict
+        # Update batch_data for feed_dict
         feed_dict = self._update_batch_data(args)
 
         fetch_list = self._program_cache.outputs
 
-        # Adds optimizer if needed.
-        if self._optimizer and not self._already_minimized:
+        # Add optimizer if needed.
+        if self._optimizer_info and self._optimizer is None:
             self._add_optimizer()
 
         if self._need_startup:
@@ -397,7 +400,7 @@
 
     def _check_cache_valid(self):
         """
-        Checks whether the current program is consistent with `default_main_program`.
+        Check whether the current program is consistent with `default_main_program`.
         In some models and unittest, program will be switched frequently by `program_guard`.
         If does, the cached program and other properties are not available and should be reset.
         """
@@ -408,7 +411,7 @@
 
     def _update_batch_data(self, args):
         """
-        Updates cached batch data while training program.
+        Update cached batch data during training.
         """
         feed_name_to_idx = self._program_cache.feed_name_to_idx
         feed_vars = self._program_cache.inputs
@@ -421,27 +424,40 @@
 
     def _add_optimizer(self):
         """
-        Supports to set or update the optimizer used to minimize loss.
+        Set or update the optimizer used to minimize loss.
""" + optimizer, index_of_loss = self._optimizer_info + + outputs = self._program_cache.outputs + outputs = [outputs] if not isinstance(outputs, + (list, tuple)) else outputs + + assert abs(index_of_loss) < len(outputs), \ + "index_of_loss: {} shall not exceed the length of outputs: {}.".format( + index_of_loss, len(outputs)) + + loss_var = outputs[index_of_loss] + check_type(loss_var, "loss_var", framework.Variable, + "ProgramTranslator._add_optimizer") + main_program = self._program_cache.main_program startup_program = self._program_cache.startup_program all_vars = main_program.block(0).vars - loss_var = all_vars.get(self._loss_name, None) - if loss_var is None: + if all_vars.get(loss_var.name, None) is None: raise ValueError( - "Can't find {} in main_program, please confirm whether the loss input is correct" - .format(self._loss_name)) - # Adds optimizer to minimize loss + "Can't find {} in main_program, please confirm whether the input loss is correct." + .format(loss_var.name)) + # Add optimizer to minimize loss with framework.program_guard(main_program, startup_program): - self._optimizer.minimize(loss_var) + optimizer.minimize(loss_var) - # Avoids to set optimizer repeatedly. - self._already_minimized = True + self._optimizer = optimizer + self._loss_name = loss_var.name def get_program_cache(self): """ - Returns the ProgramCache instance. + Return the ProgramCache instance. """ self._check_cache_valid() return self._program_cache diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py index 588dd0a5f1..3b7cdae78a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py @@ -76,9 +76,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase): static_net = self.dygraph_class() adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001) # set optimizer - # TODO: Need a better interfaces to set optimizer. 
         program_translator = ProgramTranslator()
-        program_translator.set_optimizer(adam, 'avg_loss')
+        program_translator.set_optimizer(adam, index_of_loss=1)
 
         for batch_id in range(self.batch_num):
             pred, avg_loss = static_net(self.data)
@@ -110,6 +109,20 @@
             msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
                                                             static_loss))
 
+    def test_exception(self):
+        main_program = fluid.Program()
+        loss_data = []
+        with fluid.program_guard(main_program):
+            static_net = self.dygraph_class()
+            adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
+            # set optimizer
+            program_translator = ProgramTranslator()
+
+            with self.assertRaisesRegexp(ValueError, "has already been set"):
+                for batch_id in range(self.batch_num):
+                    program_translator.set_optimizer(adam, index_of_loss=1)
+                    static_net(self.data)
+
 
 def simple_func(x):
     inputs = fluid.dygraph.to_variable(x)
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
index 5dc806cd59..3c2e34aec3 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
@@ -58,7 +58,7 @@ class Linear(fluid.dygraph.Layer):
 
     def forward(self, x):
         inputs = fluid.dygraph.to_variable(x)
         pre = self.fc(inputs)
-        loss = fluid.layers.mean(pre, name='avg_loss')
+        loss = fluid.layers.mean(pre)
         return pre, loss
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
index 1a0266c7b2..3933cd02f2 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
@@ -39,7 +39,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
         x = fluid.dygraph.to_variable(x)
         y = self._linear(x)
         z = self._linear(y)
-        out = fluid.layers.mean(z, name='mean')
+        out = fluid.layers.mean(z)
         return out
 
@@ -53,7 +53,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
         program_translator = ProgramTranslator.get_instance()
         program_cache = ProgramTranslator().get_program_cache
         adam = fluid.optimizer.SGD(learning_rate=0.001)
-        program_translator.set_optimizer(adam, 'mean')
+        program_translator.set_optimizer(adam, index_of_loss=0)
 
         for i in range(5):
             out = layer(x)
--
GitLab
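
Usage sketch (illustrative, not part of the patch): the snippet below mirrors the updated unit tests to show the polished interface. `SimpleNet`, the random input data, and the batch count are stand-ins invented for this sketch; in the real tests the network comes from `self.dygraph_class`, whose forward pass is already wired into the dygraph-to-static machinery, and the `ProgramTranslator` import path simply follows the module this patch touches.

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator


    class SimpleNet(fluid.dygraph.Layer):
        # Toy stand-in layer: forward returns (prediction, loss), like
        # test_fetch_feed.Linear after this patch.
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.fc = fluid.dygraph.Linear(10, 3)

        def forward(self, x):
            inputs = fluid.dygraph.to_variable(x)
            pre = self.fc(inputs)
            loss = fluid.layers.mean(pre)
            return pre, loss


    main_program = fluid.Program()
    with fluid.program_guard(main_program):
        static_net = SimpleNet()
        adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)

        # index_of_loss indexes into the tuple returned by forward; the loss
        # is the second output here, hence index_of_loss=1. Negative indices
        # work too, since the new code only checks abs(index_of_loss) < len(outputs).
        program_translator = ProgramTranslator()
        program_translator.set_optimizer(adam, index_of_loss=1)

        data = np.random.random((4, 10)).astype('float32')
        for batch_id in range(3):
            pred, avg_loss = static_net(data)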
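
Note on the new repeated-call guard: `set_optimizer` itself only records `(optimizer, index_of_loss)` in `_optimizer_info`; `_optimizer` and `_loss_name` are not filled in until `_prepare` calls `_add_optimizer` on the first run. A second `set_optimizer` call therefore raises the "has already been set" ValueError only after the program has actually been run once, which is why `test_exception` invokes `static_net(self.data)` inside the loop between calls.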