Unverified commit d37cd740, authored by Aurelius84, committed by GitHub

Polish set_optimizer Interface (#23588)

Parent: f301eb7f
@@ -20,6 +20,7 @@ import six
 import textwrap
 import threading
 import warnings
+from collections import defaultdict
 from paddle.fluid import framework
 from paddle.fluid import core, executor
@@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
 from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
 from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.data_feeder import check_type

 __all__ = ['ProgramTranslator', 'convert_function_with_cache']
@@ -261,19 +263,20 @@ class ProgramTranslator(object):
         else:
             self._exe = exe
         self._program_cache = ProgramCache()
+        self._optimizer_info = None
         self._optimizer = None
-        self._already_minimized = False
+        self._loss_name = None
         # Once main_program is changed, should run startup_program.
         self._need_startup = True

     def get_output(self, dygraph_func, *args, **kwargs):
         """
-        Returns the output tensors for dygraph function and its arguments
+        Return the output tensors for dygraph function and its arguments
         """
         if in_dygraph_mode():
             warnings.warn(
                 "The ProgramTranslator.get_output doesn't work in dygraph "
-                "mode. We will just return dygraph output. Use the it in "
+                "mode. We will just return dygraph output. Use it in "
                 "static mode if you would like to translate to static graph.")
             return dygraph_func(*args, **kwargs)
@@ -286,12 +289,12 @@ class ProgramTranslator(object):
     def get_func(self, dygraph_func):
         """
-        Returns the translated static function from dygraph function
+        Return the translated static function from dygraph function
         """
         if in_dygraph_mode():
             warnings.warn(
                 "The ProgramTranslator.get_func doesn't work in dygraph "
-                "mode. We will just return dygraph function. Use the it in "
+                "mode. We will just return dygraph function. Use it in "
                 "static mode if you would like to translate to static graph.")
             return dygraph_func
         static_func = convert_function_with_cache(dygraph_func)
@@ -299,7 +302,7 @@ class ProgramTranslator(object):
     def get_program(self, dygraph_func, *args, **kwargs):
         """
-        Returns the translated static program and input/output variables from
+        Return the translated static program and input/output variables from
         dygraph function.
         """
         if in_dygraph_mode():
@@ -315,7 +318,7 @@ class ProgramTranslator(object):
     def get_code(self, dygraph_func):
         """
-        Returns the translated static function code from dygraph code
+        Return the translated static function code from dygraph code
         """
         # Get AST from dygraph function
         raw_code = inspect.getsource(dygraph_func)
@@ -332,7 +335,7 @@ class ProgramTranslator(object):
     def run(self, *args, **kwargs):
         """
-        Executes main_program and returns output Tensors.
+        Execute main_program and return output Tensors.
         """
         feed_dict, fetch_list = self._prepare(args)
@@ -343,18 +346,18 @@ class ProgramTranslator(object):
         return outputs

-    def set_optimizer(self, optimizer, loss_name):
+    def set_optimizer(self, optimizer, index_of_loss=0):
         """
-        Supports to set or update the optimizer used to minimize loss.
+        Set or update the optimizer used to minimize loss.
         """
+        check_type(index_of_loss, "index_of_loss", int,
+                   "ProgramTranslator.set_optimizer")
         self._check_cache_valid()
-        self._optimizer = optimizer
-        if not isinstance(loss_name, six.string_types):
+        if self._optimizer and self._loss_name:
             raise ValueError(
-                "Type of input loss_name should type(str), but received {}.".
-                format(type(loss_name)))
-        self._loss_name = loss_name
+                "{} for {} has already been set before. Please confirm not to call `set_optimizer` in a for loop.".
+                format(self._optimizer, self._loss_name))
+        self._optimizer_info = (optimizer, index_of_loss)

     def save_inference_model(self, dirname, feed=None, fetch=None):
         """
@@ -377,16 +380,16 @@ class ProgramTranslator(object):
     def _prepare(self, args):
         """
-        Prepares with feed_dict, fetch_list, optimizer and initialize vars
+        Prepare feed_dict, fetch_list and the optimizer, and initialize vars
         by running startup_program.
         """
-        # Updates batch_data for feed_dict
+        # Update batch_data for feed_dict
         feed_dict = self._update_batch_data(args)
         fetch_list = self._program_cache.outputs
-        # Adds optimizer if needed.
-        if self._optimizer and not self._already_minimized:
+        # Add optimizer if needed.
+        if self._optimizer_info and self._optimizer is None:
             self._add_optimizer()

         if self._need_startup:
@@ -397,7 +400,7 @@ class ProgramTranslator(object):
     def _check_cache_valid(self):
         """
-        Checks whether the current program is consistent with `default_main_program`.
+        Check whether the current program is consistent with `default_main_program`.
         In some models and unittest, program will be switched frequently by `program_guard`.
         If does, the cached program and other properties are not available and should be reset.
         """
@@ -408,7 +411,7 @@ class ProgramTranslator(object):
     def _update_batch_data(self, args):
         """
-        Updates cached batch data while training program.
+        Update cached batch data while training program.
         """
         feed_name_to_idx = self._program_cache.feed_name_to_idx
         feed_vars = self._program_cache.inputs
@@ -421,27 +424,40 @@ class ProgramTranslator(object):
     def _add_optimizer(self):
         """
-        Supports to set or update the optimizer used to minimize loss.
+        Set or update the optimizer used to minimize loss.
         """
+        optimizer, index_of_loss = self._optimizer_info
+        outputs = self._program_cache.outputs
+        outputs = [outputs] if not isinstance(outputs,
+                                              (list, tuple)) else outputs
+        assert abs(index_of_loss) < len(outputs), \
+            "index_of_loss: {} shall not exceed the length of outputs: {}.".format(
+                index_of_loss, len(outputs))
+        loss_var = outputs[index_of_loss]
+        check_type(loss_var, "loss_var", framework.Variable,
+                   "ProgramTranslator._add_optimizer")
         main_program = self._program_cache.main_program
         startup_program = self._program_cache.startup_program
         all_vars = main_program.block(0).vars
-        loss_var = all_vars.get(self._loss_name, None)
-        if loss_var is None:
+        if all_vars.get(loss_var.name, None) is None:
             raise ValueError(
-                "Can't find {} in main_program, please confirm whether the loss input is correct"
-                .format(self._loss_name))
+                "Can't find {} in main_program, please confirm whether the input loss is correct."
+                .format(loss_var.name))
-        # Adds optimizer to minimize loss
+        # Add optimizer to minimize loss
         with framework.program_guard(main_program, startup_program):
-            self._optimizer.minimize(loss_var)
+            optimizer.minimize(loss_var)
-        # Avoids to set optimizer repeatedly.
-        self._already_minimized = True
+        self._optimizer = optimizer
+        self._loss_name = loss_var.name

     def get_program_cache(self):
         """
-        Returns the ProgramCache instance.
+        Return the ProgramCache instance.
         """
         self._check_cache_valid()
         return self._program_cache
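To make the new index_of_loss semantics concrete, here is a standalone sketch of the resolution rule in `_add_optimizer` above (the helper name `resolve_loss_var` is hypothetical, for illustration only): outputs are normalized to a list, and any index with `abs(index_of_loss) < len(outputs)` passes the guard, so negative Python-style indices are accepted as well.

# Hypothetical helper mirroring the _add_optimizer logic above.
def resolve_loss_var(outputs, index_of_loss=0):
    # Normalize a single output into a list, exactly as _add_optimizer does.
    outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
    assert abs(index_of_loss) < len(outputs), \
        "index_of_loss: {} shall not exceed the length of outputs: {}.".format(
            index_of_loss, len(outputs))
    return outputs[index_of_loss]

# With outputs (pred, avg_loss): index 1 and -1 both select avg_loss.
outs = ('pred', 'avg_loss')
assert resolve_loss_var(outs, 1) == 'avg_loss'
assert resolve_loss_var(outs, -1) == 'avg_loss'

Note the guard is symmetric: with two outputs only -1, 0 and 1 pass, even though -2 would be a valid Python index.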
@@ -76,9 +76,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
             static_net = self.dygraph_class()
             adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
             # set optimizer
-            # TODO: Need a better interfaces to set optimizer.
             program_translator = ProgramTranslator()
-            program_translator.set_optimizer(adam, 'avg_loss')
+            program_translator.set_optimizer(adam, index_of_loss=1)
             for batch_id in range(self.batch_num):
                 pred, avg_loss = static_net(self.data)
@@ -110,6 +109,20 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
             msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
                                                             static_loss))

+    def test_exception(self):
+        main_program = fluid.Program()
+        loss_data = []
+        with fluid.program_guard(main_program):
+            static_net = self.dygraph_class()
+            adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
+            # set optimizer
+            program_translator = ProgramTranslator()
+            with self.assertRaisesRegexp(ValueError, "has already been set"):
+                for batch_id in range(self.batch_num):
+                    program_translator.set_optimizer(adam, index_of_loss=1)
+                    static_net(self.data)
+

 def simple_func(x):
     inputs = fluid.dygraph.to_variable(x)
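Read together with the `Linear` fixture below, a minimal usage sketch of the reworked interface (the import path, layer sizes and numpy input are assumptions for illustration; the flow mirrors the test above): the converted network returns `(pre, loss)`, so `index_of_loss=1` selects the loss, and `optimizer.minimize` runs lazily once `_prepare` sees `_optimizer_info`.

import numpy as np
import paddle.fluid as fluid
# Import path assumed for illustration; it may differ across Paddle versions.
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator

static_net = Linear()  # fixture below: forward returns (pre, loss)
adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
program_translator = ProgramTranslator()
# Set the optimizer once, outside the batch loop; calling set_optimizer
# again raises the "has already been set" ValueError tested above.
program_translator.set_optimizer(adam, index_of_loss=1)

data = np.random.random((4, 10)).astype('float32')  # shape is an assumption
for batch_id in range(3):
    pred, avg_loss = static_net(data)  # first call builds the program and
                                       # attaches the optimizer via _prepare()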
@@ -58,7 +58,7 @@ class Linear(fluid.dygraph.Layer):
     def forward(self, x):
         inputs = fluid.dygraph.to_variable(x)
         pre = self.fc(inputs)
-        loss = fluid.layers.mean(pre, name='avg_loss')
+        loss = fluid.layers.mean(pre)
         return pre, loss
@@ -39,7 +39,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
         x = fluid.dygraph.to_variable(x)
         y = self._linear(x)
         z = self._linear(y)
-        out = fluid.layers.mean(z, name='mean')
+        out = fluid.layers.mean(z)
         return out
@@ -53,7 +53,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
-        program_translator = ProgramTranslator.get_instance()
+        program_cache = ProgramTranslator().get_program_cache
         adam = fluid.optimizer.SGD(learning_rate=0.001)
-        program_translator.set_optimizer(adam, 'mean')
+        program_translator.set_optimizer(adam, index_of_loss=0)
         for i in range(5):
             out = layer(x)