Unverified commit d37cd740, authored by Aurelius84, committed by GitHub

Polish set_optimizer Interface (#23588)

Parent f301eb7f
......@@ -20,6 +20,7 @@ import six
import textwrap
import threading
import warnings
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core, executor
......@@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_type
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
......@@ -261,19 +263,20 @@ class ProgramTranslator(object):
else:
self._exe = exe
self._program_cache = ProgramCache()
self._optimizer_info = None
self._optimizer = None
self._already_minimized = False
self._loss_name = None
# Once main_program is changed, should run startup_program.
self._need_startup = True
def get_output(self, dygraph_func, *args, **kwargs):
"""
Returns the output tensors for dygraph function and its arguments
Return the output tensors for dygraph function and its arguments
"""
if in_dygraph_mode():
warnings.warn(
"The ProgramTranslator.get_output doesn't work in dygraph "
"mode. We will just return dygraph output. Use the it in "
"mode. We will just return dygraph output. Use it in "
"static mode if you would like to translate to static graph.")
return dygraph_func(*args, **kwargs)
......@@ -286,12 +289,12 @@ class ProgramTranslator(object):
def get_func(self, dygraph_func):
"""
Returns the translated static function from dygraph function
Return the translated static function from dygraph function
"""
if in_dygraph_mode():
warnings.warn(
"The ProgramTranslator.get_func doesn't work in dygraph "
"mode. We will just return dygraph function. Use the it in "
"mode. We will just return dygraph function. Use it in "
"static mode if you would like to translate to static graph.")
return dygraph_func
static_func = convert_function_with_cache(dygraph_func)
......@@ -299,7 +302,7 @@ class ProgramTranslator(object):
def get_program(self, dygraph_func, *args, **kwargs):
"""
Returns the translated static program and input/output variables from
Return the translated static program and input/output variables from
dygraph function.
"""
if in_dygraph_mode():
......@@ -315,7 +318,7 @@ class ProgramTranslator(object):
def get_code(self, dygraph_func):
"""
Returns the translated static function code from dygraph code
Return the translated static function code from dygraph code
"""
# Get AST from dygraph function
raw_code = inspect.getsource(dygraph_func)
......@@ -332,7 +335,7 @@ class ProgramTranslator(object):
def run(self, *args, **kwargs):
"""
Executes main_program and returns output Tensors.
Execute main_program and returns output Tensors.
"""
feed_dict, fetch_list = self._prepare(args)
......@@ -343,18 +346,18 @@ class ProgramTranslator(object):
return outputs
def set_optimizer(self, optimizer, loss_name):
def set_optimizer(self, optimizer, index_of_loss=0):
"""
Supports to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
check_type(index_of_loss, "index_of_loss", int,
"ProgramTranslator.set_optimizer")
self._check_cache_valid()
self._optimizer = optimizer
if not isinstance(loss_name, six.string_types):
if self._optimizer and self._loss_name:
raise ValueError(
"Type of input loss_name should type(str), but received {}.".
format(type(loss_name)))
self._loss_name = loss_name
"{} for {} has already been set before. Please confirm not to call `set_optimizer` in for loop. ".
format(self._optimizer, self._loss_name))
self._optimizer_info = (optimizer, index_of_loss)
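For context on how the reworked interface is meant to be called, here is a minimal usage sketch modeled on the tests further down in this diff. The `Linear` layer, the input data, and the exact invocation details are assumptions for illustration (the import path is taken from the module edited here); the key point is that the optimizer is registered once, outside the batch loop, together with the index of the loss among the forward outputs, instead of by loss variable name.

```python
import numpy as np
import paddle.fluid as fluid
# Import path assumed from the module edited in this diff.
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator

main_program = fluid.Program()
with fluid.program_guard(main_program):
    # Hypothetical layer like the Linear class in the tests below,
    # whose (translated) forward returns (pred, avg_loss).
    net = Linear()
    adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)

    program_translator = ProgramTranslator()
    # Register the optimizer once, before the loop; the loss is the second
    # forward output, so index_of_loss=1 (negative indices also work).
    # Calling set_optimizer again inside the loop now raises a ValueError
    # ("has already been set"), as exercised by test_exception below.
    program_translator.set_optimizer(adam, index_of_loss=1)

    data = np.random.random((4, 10)).astype('float32')
    for batch_id in range(10):
        pred, avg_loss = net(data)
```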
def save_inference_model(self, dirname, feed=None, fetch=None):
"""
......@@ -377,16 +380,16 @@ class ProgramTranslator(object):
def _prepare(self, args):
"""
Prepares with feed_dict, fetch_list, optimizer and initialize vars
Prepare with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program.
"""
# Updates batch_data for feed_dict
# Update batch_data for feed_dict
feed_dict = self._update_batch_data(args)
fetch_list = self._program_cache.outputs
# Adds optimizer if needed.
if self._optimizer and not self._already_minimized:
# Add optimizer if needed.
if self._optimizer_info and self._optimizer is None:
self._add_optimizer()
if self._need_startup:
......@@ -397,7 +400,7 @@ class ProgramTranslator(object):
def _check_cache_valid(self):
"""
Checks whether the current program is consistent with `default_main_program`.
Check whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`.
If so, the cached program and other properties are not available and should be reset.
"""
......@@ -408,7 +411,7 @@ class ProgramTranslator(object):
def _update_batch_data(self, args):
"""
Updates cached batch data while training program.
Update cached batch data while training program.
"""
feed_name_to_idx = self._program_cache.feed_name_to_idx
feed_vars = self._program_cache.inputs
......@@ -421,27 +424,40 @@ class ProgramTranslator(object):
def _add_optimizer(self):
"""
Supports to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
optimizer, index_of_loss = self._optimizer_info
outputs = self._program_cache.outputs
outputs = [outputs] if not isinstance(outputs,
(list, tuple)) else outputs
assert abs(index_of_loss) < len(outputs), \
"index_of_loss: {} shall not exceed the length of outputs: {}.".format(
index_of_loss, len(outputs))
loss_var = outputs[index_of_loss]
check_type(loss_var, "loss_var", framework.Variable,
"ProgramTranslator._add_optimizer")
main_program = self._program_cache.main_program
startup_program = self._program_cache.startup_program
all_vars = main_program.block(0).vars
loss_var = all_vars.get(self._loss_name, None)
if loss_var is None:
if all_vars.get(loss_var.name, None) is None:
raise ValueError(
"Can't find {} in main_program, please confirm whether the loss input is correct"
.format(self._loss_name))
# Adds optimizer to minimize loss
"Can't find {} in main_program, please confirm whether the input loss is correct."
.format(loss_var.name))
# Add optimizer to minimize loss
with framework.program_guard(main_program, startup_program):
self._optimizer.minimize(loss_var)
optimizer.minimize(loss_var)
# Avoids to set optimizer repeatedly.
self._already_minimized = True
self._optimizer = optimizer
self._loss_name = loss_var.name
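As a plain-Python illustration of the index handling above (the stand-in strings below are not real Variables), `index_of_loss` is applied directly to the cached forward outputs, so a negative index counts from the end and the `abs(index_of_loss) < len(outputs)` guard rejects out-of-range values:

```python
# Stand-in for self._program_cache.outputs; the real entries are Variables.
outputs = ["pred_var", "avg_loss_var"]

def pick_loss(outputs, index_of_loss=0):
    # Mirrors the wrapping and bounds check in _add_optimizer above.
    outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
    assert abs(index_of_loss) < len(outputs), \
        "index_of_loss: {} shall not exceed the length of outputs: {}.".format(
            index_of_loss, len(outputs))
    return outputs[index_of_loss]

print(pick_loss(outputs, index_of_loss=1))   # avg_loss_var
print(pick_loss(outputs, index_of_loss=-1))  # avg_loss_var, counted from the end
```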
def get_program_cache(self):
"""
Returns the ProgramCache instance.
Return the ProgramCache instance.
"""
self._check_cache_valid()
return self._program_cache
......
......@@ -76,9 +76,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
static_net = self.dygraph_class()
adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
# set optimizer
# TODO: Need a better interfaces to set optimizer.
program_translator = ProgramTranslator()
program_translator.set_optimizer(adam, 'avg_loss')
program_translator.set_optimizer(adam, index_of_loss=1)
for batch_id in range(self.batch_num):
pred, avg_loss = static_net(self.data)
......@@ -110,6 +109,20 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss))
def test_exception(self):
main_program = fluid.Program()
loss_data = []
with fluid.program_guard(main_program):
static_net = self.dygraph_class()
adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
# set optimizer
program_translator = ProgramTranslator()
with self.assertRaisesRegexp(ValueError, "has already been set"):
for batch_id in range(self.batch_num):
program_translator.set_optimizer(adam, index_of_loss=1)
static_net(self.data)
def simple_func(x):
inputs = fluid.dygraph.to_variable(x)
......
......@@ -58,7 +58,7 @@ class Linear(fluid.dygraph.Layer):
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
pre = self.fc(inputs)
loss = fluid.layers.mean(pre, name='avg_loss')
loss = fluid.layers.mean(pre)
return pre, loss
......
......@@ -39,7 +39,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
x = fluid.dygraph.to_variable(x)
y = self._linear(x)
z = self._linear(y)
out = fluid.layers.mean(z, name='mean')
out = fluid.layers.mean(z)
return out
......@@ -53,7 +53,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
program_translator = ProgramTranslator.get_instance()
program_cache = ProgramTranslator().get_program_cache
adam = fluid.optimizer.SGD(learning_rate=0.001)
program_translator.set_optimizer(adam, 'mean')
program_translator.set_optimizer(adam, index_of_loss=0)
for i in range(5):
out = layer(x)
......