提交 4cf61262 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Improve TFGAN documentation.

PiperOrigin-RevId: 170940188
上级 0068086b
......@@ -14,10 +14,41 @@
# ==============================================================================
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
Example:
The losses and penalties in this file all correspond to losses in
`losses_impl.py`. Losses in that file take individual arguments, whereas in this
file they take a `GANModel` tuple. For example:
losses_impl.py:
```python
def wasserstein_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False)
```
tuple_losses_impl.py:
```python
def wasserstein_discriminator_loss(
gan_model,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False)
```
Example usage:
```python
# `tfgan.losses.args` losses take individual arguments.
w_loss = tfgan.losses.args.wasserstein_discriminator_loss(
# `tfgan.losses.wargs` losses take individual arguments.
w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs)
......
......@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Named tuples for TFGAN."""
"""Named tuples for TFGAN.
TFGAN training occurs in four steps, and each step communicates with the next
step via one of these named tuples. At each step, you can either use a TFGAN
helper function in `train.py`, or you can manually construct a tuple.
"""
from __future__ import absolute_import
from __future__ import division
......
......@@ -14,7 +14,17 @@
# ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework.
See examples in `tensorflow_models` for details on how to use.
This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use.
TFGAN training occurs in four steps:
1) Create a model
2) Add a loss
3) Create train ops
4) Run the train ops
The functions in this file are organized around these four steps. Each function
corresponds to one of the steps.
"""
from __future__ import absolute_import
......@@ -51,16 +61,6 @@ __all__ = [
]
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
if isinstance(tensor_or_l_or_d, (list, tuple)):
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
elif isinstance(tensor_or_l_or_d, dict):
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
else:
return ops.convert_to_tensor(tensor_or_l_or_d)
def gan_model(
# Lambdas defining models.
generator_fn,
......@@ -133,20 +133,6 @@ def gan_model(
discriminator_fn)
def _validate_distributions(distributions_l, noise_l):
if not isinstance(distributions_l, (tuple, list)):
raise ValueError('`predicted_distributions` must be a list. Instead, found '
'%s.' % type(distributions_l))
for dist in distributions_l:
if not isinstance(dist, ds.Distribution):
raise ValueError('Every element in `predicted_distributions` must be a '
'`tf.Distribution`. Instead, found %s.' % type(dist))
if len(distributions_l) != len(noise_l):
raise ValueError('Length of `predicted_distributions` %i must be the same '
'as the length of structured noise %i.' %
(len(distributions_l), len(noise_l)))
def infogan_model(
# Lambdas defining models.
generator_fn,
......@@ -231,16 +217,6 @@ def infogan_model(
predicted_distributions)
def _validate_acgan_discriminator_outputs(discriminator_output):
try:
a, b = discriminator_output
except (TypeError, ValueError):
raise TypeError(
'A discriminator function for ACGAN must output a tuple '
'consisting of (discrimination logits, classification logits).')
return a, b
def acgan_model(
# Lambdas defining models.
generator_fn,
......@@ -252,6 +228,7 @@ def acgan_model(
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
# Options.
check_shapes=True):
"""Returns an ACGANModel contains all the pieces needed for ACGAN training.
......@@ -497,11 +474,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
def gan_train_ops(
model, # GANModel
loss, # GANLoss
model,
loss,
generator_optimizer,
discriminator_optimizer,
# Optional check flags.
check_for_unused_update_ops=True,
# Optional args to pass directly to the `create_train_op`.
**kwargs):
......@@ -801,3 +777,40 @@ def get_sequential_train_steps(
return gen_loss + dis_loss, should_stop
return sequential_train_steps
# Helpers
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
if isinstance(tensor_or_l_or_d, (list, tuple)):
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
elif isinstance(tensor_or_l_or_d, dict):
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
else:
return ops.convert_to_tensor(tensor_or_l_or_d)
def _validate_distributions(distributions_l, noise_l):
if not isinstance(distributions_l, (tuple, list)):
raise ValueError('`predicted_distributions` must be a list. Instead, found '
'%s.' % type(distributions_l))
for dist in distributions_l:
if not isinstance(dist, ds.Distribution):
raise ValueError('Every element in `predicted_distributions` must be a '
'`tf.Distribution`. Instead, found %s.' % type(dist))
if len(distributions_l) != len(noise_l):
raise ValueError('Length of `predicted_distributions` %i must be the same '
'as the length of structured noise %i.' %
(len(distributions_l), len(noise_l)))
def _validate_acgan_discriminator_outputs(discriminator_output):
try:
a, b = discriminator_output
except (TypeError, ValueError):
raise TypeError(
'A discriminator function for ACGAN must output a tuple '
'consisting of (discrimination logits, classification logits).')
return a, b
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册