Commit f9bd951e authored by Lukasz Kaiser, committed by GitHub

Merge pull request #1382 from rsepassi/adv_text

Add adversarial text model
# Binaries
# ==============================================================================
py_binary(
name = "evaluate",
srcs = ["evaluate.py"],
deps = [
":graphs",
],
)
py_binary(
name = "train_classifier",
srcs = ["train_classifier.py"],
deps = [
":graphs",
":train_utils",
],
)
py_binary(
name = "pretrain",
srcs = [
"pretrain.py",
],
deps = [
":graphs",
":train_utils",
],
)
# Libraries
# ==============================================================================
py_library(
name = "graphs",
srcs = ["graphs.py"],
deps = [
":adversarial_losses",
":inputs",
":layers",
],
)
py_library(
name = "adversarial_losses",
srcs = ["adversarial_losses.py"],
)
py_library(
name = "inputs",
srcs = ["inputs.py"],
deps = [
"//adversarial_text/data:data_utils",
],
)
py_library(
name = "layers",
srcs = ["layers.py"],
)
py_library(
name = "train_utils",
srcs = ["train_utils.py"],
)
# Tests
# ==============================================================================
py_test(
name = "graphs_test",
size = "large",
srcs = ["graphs_test.py"],
deps = [
":graphs",
"//adversarial_text/data:data_utils",
],
)
# Adversarial Text Classification
Code for *Adversarial Training Methods for Semi-Supervised Text Classification*
(https://arxiv.org/abs/1605.07725) and *Semi-Supervised Sequence Learning*
(https://arxiv.org/abs/1511.01432).
## Requirements
* Bazel ([install](https://bazel.build/versions/master/docs/install.html))
* TensorFlow >= v1.1
## End-to-end IMDB Sentiment Classification
### Fetch data
```
$ wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz \
-O /tmp/imdb.tar.gz
$ tar -xf /tmp/imdb.tar.gz -C /tmp
```
The directory `/tmp/aclImdb` contains the raw IMDB data.
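The data generators (see `imdb_documents` in `data/document_generators.py`)
expect the standard aclImdb layout:
```
/tmp/aclImdb/
  train/pos/    # labeled positive reviews
  train/neg/    # labeled negative reviews
  train/unsup/  # unlabeled reviews
  test/pos/
  test/neg/
```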
### Generate vocabulary
```
$ IMDB_DATA_DIR=/tmp/imdb
$ bazel run data:gen_vocab -- \
--output_dir=$IMDB_DATA_DIR \
--dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \
--lowercase=False
```
Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`.
### Generate training, validation, and test data
```
$ bazel run data:gen_data -- \
--output_dir=$IMDB_DATA_DIR \
--dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \
--lowercase=False \
--label_gain=False
```
`$IMDB_DATA_DIR` now contains the generated TFRecord files.
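Each record is a `tf.train.SequenceExample` with `token_id`, `label`, and
`weight` feature lists (see `data/data_utils.py`). A minimal sketch for
sanity-checking one record of the classification split (filename taken from
`data_utils.py`; adjust the path to your `$IMDB_DATA_DIR`):
```
import os
import tensorflow as tf

data_dir = '/tmp/imdb'  # $IMDB_DATA_DIR
path = os.path.join(data_dir, 'train_classification.tfrecords')

# Read the first serialized record and parse it as a SequenceExample.
record = next(tf.python_io.tf_record_iterator(path))
seq = tf.train.SequenceExample()
seq.ParseFromString(record)

token_ids = seq.feature_lists.feature_list['token_id'].feature
print('num timesteps:', len(token_ids))
print('first token ids:', [f.int64_list.value[0] for f in token_ids[:10]])
```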
### Pretrain IMDB Language Model
```
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ bazel run :pretrain -- \
--train_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--num_candidate_samples=1024 \
--optimizer=adam \
--batch_size=256 \
--learning_rate=0.001 \
--learning_rate_decay_factor=0.9999 \
--max_steps=100000 \
--max_grad_norm=1.0 \
--num_timesteps=400 \
--keep_prob_emb=0.5 \
--normalize_embeddings
```
`$PRETRAIN_DIR` contains checkpoints of the pretrained language model.
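To verify that pretraining produced usable checkpoints, you can list the saved
variables with standard TensorFlow checkpoint utilities (a quick sketch; the
exact variable names depend on the graph built in `graphs.py`):
```
import tensorflow as tf

ckpt = tf.train.latest_checkpoint('/tmp/models/imdb_pretrain')
reader = tf.train.NewCheckpointReader(ckpt)
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
  print(name, shape)
```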
### Train classifier
Most flags stay the same. The differences are the removal of candidate
sampling, the addition of `pretrained_model_dir` (from which the classifier
loads the pretrained embedding and LSTM variables), and the addition of flags
related to adversarial training and classification; `--adv_training_method`
accepts `rp`, `at`, `vat`, or `atvat` (see `adversarial_losses.py`).
```
$ TRAIN_DIR=/tmp/models/imdb_classify
$ bazel run :train_classifier -- \
--train_dir=$TRAIN_DIR \
--pretrained_model_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--cl_num_layers=1 \
--cl_hidden_size=30 \
--optimizer=adam \
--batch_size=64 \
--learning_rate=0.0005 \
--learning_rate_decay_factor=0.9998 \
--max_steps=15000 \
--max_grad_norm=1.0 \
--num_timesteps=400 \
--keep_prob_emb=0.5 \
--normalize_embeddings \
--adv_training_method=vat
```
### Evaluate on test data
```
$ EVAL_DIR=/tmp/models/imdb_eval
$ bazel run :evaluate -- \
--eval_dir=$EVAL_DIR \
--checkpoint_dir=$TRAIN_DIR \
--eval_data=test \
--run_once \
--num_examples=25000 \
--data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--batch_size=256 \
--num_timesteps=400 \
--normalize_embeddings
```
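Evaluation writes its metric summaries as TensorFlow event files to
`$EVAL_DIR`, so training and evaluation runs can be monitored together with
TensorBoard:
```
$ tensorboard --logdir=/tmp/models
```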
## Code Overview
The main entry points are the binaries listed below. Each training binary builds
a `VatxtModel`, defined in `graphs.py`, which in turn uses graph building blocks
defined in `inputs.py` (defines input data reading and parsing), `layers.py`
(defines core model components), and `adversarial_losses.py` (defines
adversarial training losses). The training loop itself is defined in
`train_utils.py`.
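As a rough sketch of how these pieces fit together (the real training loop,
with distribution, checkpointing, and summary handling, lives in
`train_utils.py`), classifier training reduces to:
```
import tensorflow as tf
import graphs  # builds VatxtModel from inputs.py, layers.py, adversarial_losses.py

model = graphs.get_model()  # VatxtModel, or VatxtBidirModel if --bidir_lstm
train_op, loss, global_step = model.classifier_training()

# Minimal loop; flags (data_dir, vocab_size, etc.) must already be set.
with tf.train.MonitoredTrainingSession() as sess:
  while not sess.should_stop():
    _, loss_val, step = sess.run([train_op, loss, global_step])
```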
### Binaries
* Pretraining: `pretrain.py`
* Classifier Training: `train_classifier.py`
* Evaluation: `evaluate.py`
### Command-Line Flags
Flags related to distributed training and the training loop itself are defined
in `train_utils.py`.
Flags related to model hyperparameters are defined in `graphs.py`.
Flags related to adversarial training are defined in `adversarial_losses.py`.
Flags particular to each job are defined in the main binary files.
### Data Generation
* Vocabulary generation: `gen_vocab.py`
* Data generation: `gen_data.py`
Command-line flags defined in `document_generators.py` control which dataset is
processed and how.
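The same binaries handle the other supported datasets (`dbpedia`, `rcv1`, and
`rt`); for example, to generate DBpedia vocabulary and data (the input path
here is a placeholder):
```
$ bazel run data:gen_vocab -- \
--output_dir=/tmp/dbpedia \
--dataset=dbpedia \
--dbpedia_input_dir=/path/to/dbpedia_csv
$ bazel run data:gen_data -- \
--output_dir=/tmp/dbpedia \
--dataset=dbpedia \
--dbpedia_input_dir=/path/to/dbpedia_csv
```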
## Contact for Issues
* Ryan Sepassi, @rsepassi
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adversarial losses for text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 0.1,
'Norm length of the adversarial perturbation; '
'typically tuned on the validation set.')
# Virtual adversarial training parameters
flags.DEFINE_integer('num_power_iteration', 1, 'The number of power iterations.')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-3,
'Small constant for finite difference method')
# Parameters for building the graph
flags.DEFINE_string('adv_training_method', None,
'The adversarial training method to use: '
'"rp" : random perturbation training '
'"at" : adversarial training '
'"vat" : virtual adversarial training '
'"atvat" : at + vat ')
flags.DEFINE_float('adv_reg_coeff', 1.0,
'Regularization coefficient of adversarial loss.')
def random_perturbation_loss(embedded, length, loss_fn):
"""Adds noise to embeddings and recomputes classification loss."""
noise = tf.random_normal(shape=tf.shape(embedded))
perturb = _scale_l2(_mask_by_length(noise, length), FLAGS.perturb_norm_length)
return loss_fn(embedded + perturb)
def adversarial_loss(embedded, loss, loss_fn):
"""Adds gradient to embedding and recomputes classification loss."""
grad, = tf.gradients(
loss,
embedded,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
grad = tf.stop_gradient(grad)
perturb = _scale_l2(grad, FLAGS.perturb_norm_length)
return loss_fn(embedded + perturb)
def virtual_adversarial_loss(logits, embedded, inputs,
logits_from_embedding_fn):
"""Virtual adversarial loss.
Computes virtual adversarial perturbation by finite difference method and
power iteration, adds it to the embedding, and computes the KL divergence
between the new logits and the original logits.
Args:
logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if
num_classes=2, otherwise m=num_classes.
embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
inputs: VatxtInput.
logits_from_embedding_fn: callable that takes embeddings and returns
classifier logits.
Returns:
kl: float scalar.
"""
# Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
logits = tf.stop_gradient(logits)
weights = _end_of_seq_mask(inputs.labels)
# shape(embedded) = (batch_size, num_timesteps, embedding_dim)
d = _mask_by_length(tf.random_normal(shape=tf.shape(embedded)), inputs.length)
# Perform finite difference method and power iteration.
# See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf,
# Adding small noise to input and taking gradient with respect to the noise
# corresponds to 1 power iteration.
for _ in xrange(FLAGS.num_power_iteration):
d = _scale_l2(d, FLAGS.small_constant_for_finite_diff)
d_logits = logits_from_embedding_fn(embedded + d)
kl = _kl_divergence_with_logits(logits, d_logits, weights)
d, = tf.gradients(
kl,
d,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
d = tf.stop_gradient(d)
perturb = _scale_l2(
_mask_by_length(d, inputs.length), FLAGS.perturb_norm_length)
vadv_logits = logits_from_embedding_fn(embedded + perturb)
return _kl_divergence_with_logits(logits, vadv_logits, weights)
def random_perturbation_loss_bidir(embedded, length, loss_fn):
"""Adds noise to embeddings and recomputes classification loss."""
noise = [tf.random_normal(shape=tf.shape(emb)) for emb in embedded]
masked = [_mask_by_length(n, length) for n in noise]
scaled = [_scale_l2(m, FLAGS.perturb_norm_length) for m in masked]
return loss_fn([e + s for (e, s) in zip(embedded, scaled)])
def adversarial_loss_bidir(embedded, loss, loss_fn):
"""Adds gradient to embeddings and recomputes classification loss."""
grads = tf.gradients(
loss,
embedded,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
adv_exs = [
emb + _scale_l2(tf.stop_gradient(g), FLAGS.perturb_norm_length)
for emb, g in zip(embedded, grads)
]
return loss_fn(adv_exs)
def virtual_adversarial_loss_bidir(logits, embedded, inputs,
logits_from_embedding_fn):
"""Virtual adversarial loss for bidirectional models."""
logits = tf.stop_gradient(logits)
f_inputs, _ = inputs
weights = _end_of_seq_mask(f_inputs.labels)
perturbs = [
_mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
for emb in embedded
]
for _ in xrange(FLAGS.num_power_iteration):
perturbs = [
_scale_l2(d, FLAGS.small_constant_for_finite_diff) for d in perturbs
]
d_logits = logits_from_embedding_fn(
[emb + d for (emb, d) in zip(embedded, perturbs)])
kl = _kl_divergence_with_logits(logits, d_logits, weights)
perturbs = tf.gradients(
kl,
perturbs,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
perturbs = [tf.stop_gradient(d) for d in perturbs]
perturbs = [
_scale_l2(_mask_by_length(d, f_inputs.length), FLAGS.perturb_norm_length)
for d in perturbs
]
vadv_logits = logits_from_embedding_fn(
[emb + d for (emb, d) in zip(embedded, perturbs)])
return _kl_divergence_with_logits(logits, vadv_logits, weights)
def _mask_by_length(t, length):
"""Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
maxlen = t.get_shape().as_list()[1]
mask = tf.sequence_mask(length, maxlen=maxlen)
mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
# shape(mask) = (batch, num_timesteps, 1)
return t * mask
def _scale_l2(x, norm_length):
# shape(x) = (batch, num_timesteps, d)
x /= (1e-12 + tf.reduce_max(tf.abs(x), 2, keep_dims=True))
x_2 = tf.reduce_sum(tf.pow(x, 2), 2, keep_dims=True)
x /= tf.sqrt(1e-6 + x_2)
return norm_length * x
def _end_of_seq_mask(tokens):
"""Generate a mask for the EOS token (1.0 on EOS, 0.0 otherwise).
Args:
tokens: 1-D integer tensor [num_timesteps*batch_size]. Each element is an
id from the vocab.
Returns:
Float tensor with the same shape as tokens, with values 1.0 at
end-of-sequence positions and 0.0 elsewhere.
"""
eos_id = FLAGS.vocab_size - 1
return tf.cast(tf.equal(tokens, eos_id), tf.float32)
def _kl_divergence_with_logits(q_logits, p_logits, weights):
"""Returns weighted KL divergence between distributions q and p.
Args:
q_logits: logits for 1st argument of KL divergence shape
[num_timesteps * batch_size, num_classes] if num_classes > 2, and
[num_timesteps * batch_size] if num_classes == 2.
p_logits: logits for 2nd argument of KL divergence, with the same shape as
q_logits.
weights: 1-D float tensor with shape [num_timesteps * batch_size].
Elements should be 1.0 only at the end of sequences.
Returns:
KL: float scalar.
"""
# For logistic regression
if FLAGS.num_classes == 2:
q = tf.nn.sigmoid(q_logits)
p = tf.nn.sigmoid(p_logits)
kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
# For softmax regression
else:
q = tf.nn.softmax(q_logits)
p = tf.nn.softmax(p_logits)
kl = tf.reduce_sum(q * (tf.log(q) - tf.log(p)), 1)
num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
return loss
package(
default_visibility = [
"//adversarial_text:__subpackages__",
],
)
py_binary(
name = "gen_vocab",
srcs = ["gen_vocab.py"],
deps = [
":data_utils",
":document_generators",
],
)
py_binary(
name = "gen_data",
srcs = ["gen_data.py"],
deps = [
":data_utils",
":document_generators",
],
)
py_library(
name = "document_generators",
srcs = ["document_generators.py"],
)
py_library(
name = "data_utils",
srcs = ["data_utils.py"],
)
py_test(
name = "data_utils_test",
srcs = ["data_utils_test.py"],
deps = [
":data_utils",
],
)
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for generating/preprocessing data for adversarial text models."""
import operator
import os
import random
import re
import tensorflow as tf
EOS_TOKEN = '</s>'
# Data filenames
# Sequence Autoencoder
ALL_SA = 'all_sa.tfrecords'
TRAIN_SA = 'train_sa.tfrecords'
TEST_SA = 'test_sa.tfrecords'
# Language Model
ALL_LM = 'all_lm.tfrecords'
TRAIN_LM = 'train_lm.tfrecords'
TEST_LM = 'test_lm.tfrecords'
# Classification
TRAIN_CLASS = 'train_classification.tfrecords'
TEST_CLASS = 'test_classification.tfrecords'
VALID_CLASS = 'validate_classification.tfrecords'
# LM with bidirectional LSTM
TRAIN_REV_LM = 'train_reverse_lm.tfrecords'
TEST_REV_LM = 'test_reverse_lm.tfrecords'
# Classification with bidirectional LSTM
TRAIN_BD_CLASS = 'train_bidir_classification.tfrecords'
TEST_BD_CLASS = 'test_bidir_classification.tfrecords'
VALID_BD_CLASS = 'validate_bidir_classification.tfrecords'
class ShufflingTFRecordWriter(object):
"""Thin wrapper around TFRecordWriter that shuffles records."""
def __init__(self, path):
self._path = path
self._records = []
self._closed = False
def write(self, record):
assert not self._closed
self._records.append(record)
def close(self):
assert not self._closed
random.shuffle(self._records)
with tf.python_io.TFRecordWriter(self._path) as f:
for record in self._records:
f.write(record)
self._closed = True
def __enter__(self):
return self
def __exit__(self, unused_type, unused_value, unused_traceback):
self.close()
class Timestep(object):
"""Represents a single timestep in a SequenceWrapper."""
def __init__(self, token, label, weight, multivalent_tokens=False):
"""Constructs Timestep from empty Features."""
self._token = token
self._label = label
self._weight = weight
self._multivalent_tokens = multivalent_tokens
self._fill_with_defaults()
@property
def token(self):
if self._multivalent_tokens:
raise TypeError('Timestep may contain multiple values; use `tokens`')
return self._token.int64_list.value[0]
@property
def tokens(self):
return self._token.int64_list.value
@property
def label(self):
return self._label.int64_list.value[0]
@property
def weight(self):
return self._weight.float_list.value[0]
def set_token(self, token):
if self._multivalent_tokens:
raise TypeError('Timestep may contain multiple values; use `add_token`')
self._token.int64_list.value[0] = token
return self
def add_token(self, token):
self._token.int64_list.value.append(token)
return self
def set_label(self, label):
self._label.int64_list.value[0] = label
return self
def set_weight(self, weight):
self._weight.float_list.value[0] = weight
return self
def copy_from(self, timestep):
self.set_token(timestep.token).set_label(timestep.label).set_weight(
timestep.weight)
return self
def _fill_with_defaults(self):
if not self._multivalent_tokens:
self._token.int64_list.value.append(0)
self._label.int64_list.value.append(0)
self._weight.float_list.value.append(0.0)
class SequenceWrapper(object):
"""Wrapper around tf.SequenceExample."""
F_TOKEN_ID = 'token_id'
F_LABEL = 'label'
F_WEIGHT = 'weight'
def __init__(self, multivalent_tokens=False):
self._seq = tf.train.SequenceExample()
self._flist = self._seq.feature_lists.feature_list
self._timesteps = []
self._multivalent_tokens = multivalent_tokens
@property
def seq(self):
return self._seq
@property
def multivalent_tokens(self):
return self._multivalent_tokens
@property
def _tokens(self):
return self._flist[SequenceWrapper.F_TOKEN_ID].feature
@property
def _labels(self):
return self._flist[SequenceWrapper.F_LABEL].feature
@property
def _weights(self):
return self._flist[SequenceWrapper.F_WEIGHT].feature
def add_timestep(self):
timestep = Timestep(
self._tokens.add(),
self._labels.add(),
self._weights.add(),
multivalent_tokens=self._multivalent_tokens)
self._timesteps.append(timestep)
return timestep
def __iter__(self):
for timestep in self._timesteps:
yield timestep
def __len__(self):
return len(self._timesteps)
def __getitem__(self, idx):
return self._timesteps[idx]
def build_reverse_sequence(seq):
"""Builds a sequence that is the reverse of the input sequence."""
reverse_seq = SequenceWrapper()
# Copy all but last timestep
for timestep in reversed(seq[:-1]):
reverse_seq.add_timestep().copy_from(timestep)
# Copy final timestep
reverse_seq.add_timestep().copy_from(seq[-1])
return reverse_seq
def build_bidirectional_seq(seq, rev_seq):
bidir_seq = SequenceWrapper(multivalent_tokens=True)
for forward_ts, reverse_ts in zip(seq, rev_seq):
bidir_seq.add_timestep().add_token(forward_ts.token).add_token(
reverse_ts.token)
return bidir_seq
def build_lm_sequence(seq):
"""Builds language model sequence from input sequence.
Args:
seq: SequenceWrapper.
Returns:
SequenceWrapper with `seq` tokens copied over to output sequence tokens and
labels (offset by 1, i.e. predict next token) with weights set to 1.0.
"""
lm_seq = SequenceWrapper()
for i, timestep in enumerate(seq[:-1]):
lm_seq.add_timestep().set_token(timestep.token).set_label(
seq[i + 1].token).set_weight(1.0)
return lm_seq
def build_seq_ae_sequence(seq):
"""Builds seq_ae sequence from input sequence.
Args:
seq: SequenceWrapper.
Returns:
SequenceWrapper with `seq` inputs copied and concatenated, and with labels
copied in on the right-hand (i.e. decoder) side with weights set to 1.0.
The new sequence will have length `len(seq) * 2 - 1`, as the last timestep
of the encoder section and the first step of the decoder section will
overlap.
"""
seq_ae_seq = SequenceWrapper()
for i in range(len(seq) * 2 - 1):
ts = seq_ae_seq.add_timestep()
if i < len(seq) - 1:
# Encoder
ts.set_token(seq[i].token)
elif i == len(seq) - 1:
# Transition step
ts.set_token(seq[i].token)
ts.set_label(seq[0].token)
ts.set_weight(1.0)
else:
# Decoder
ts.set_token(seq[i % len(seq)].token)
ts.set_label(seq[(i + 1) % len(seq)].token)
ts.set_weight(1.0)
return seq_ae_seq
def build_labeled_sequence(seq, class_label, label_gain=False):
"""Builds labeled sequence from input sequence.
Args:
seq: SequenceWrapper.
class_label: bool.
label_gain: bool. If True, class_label will be put on every timestep and
weight will increase linearly from 0 to 1.
Returns:
SequenceWrapper with `seq` copied in and `class_label` added as label to
final timestep.
"""
label_seq = SequenceWrapper(multivalent_tokens=seq.multivalent_tokens)
# Copy sequence without labels
seq_len = len(seq)
final_timestep = None
for i, timestep in enumerate(seq):
label_timestep = label_seq.add_timestep()
if seq.multivalent_tokens:
for token in timestep.tokens:
label_timestep.add_token(token)
else:
label_timestep.set_token(timestep.token)
if label_gain:
label_timestep.set_label(int(class_label))
weight = 1.0 if seq_len < 2 else float(i) / (seq_len - 1)
label_timestep.set_weight(weight)
if i == (seq_len - 1):
final_timestep = label_timestep
# Edit final timestep to have class label and weight = 1.
final_timestep.set_label(int(class_label)).set_weight(1.0)
return label_seq
def split_by_punct(segment):
"""Splits str segment by punctuation, filters our empties and spaces."""
return [s for s in re.split(r'\W+', segment) if s and not s.isspace()]
def sort_vocab_by_frequency(vocab_freq_map):
"""Sorts vocab_freq_map by count.
Args:
vocab_freq_map: dict<str term, int count>, vocabulary terms with counts.
Returns:
list<tuple<str term, int count>> sorted by count, descending.
"""
return sorted(
vocab_freq_map.items(), key=operator.itemgetter(1), reverse=True)
def write_vocab_and_frequency(ordered_vocab_freqs, output_dir):
"""Writes ordered_vocab_freqs into vocab.txt and vocab_freq.txt."""
tf.gfile.MakeDirs(output_dir)
with open(os.path.join(output_dir, 'vocab.txt'), 'w') as vocab_f:
with open(os.path.join(output_dir, 'vocab_freq.txt'), 'w') as freq_f:
for word, freq in ordered_vocab_freqs:
vocab_f.write('{}\n'.format(word))
freq_f.write('{}\n'.format(freq))
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from adversarial_text.data import data_utils
data = data_utils
class SequenceWrapperTest(tf.test.TestCase):
def testDefaultTimesteps(self):
seq = data.SequenceWrapper()
t1 = seq.add_timestep()
_ = seq.add_timestep()
self.assertEqual(len(seq), 2)
self.assertEqual(t1.weight, 0.0)
self.assertEqual(t1.label, 0)
self.assertEqual(t1.token, 0)
def testSettersAndGetters(self):
ts = data.SequenceWrapper().add_timestep()
ts.set_token(3)
ts.set_label(4)
ts.set_weight(2.0)
self.assertEqual(ts.token, 3)
self.assertEqual(ts.label, 4)
self.assertEqual(ts.weight, 2.0)
def testTimestepIteration(self):
seq = data.SequenceWrapper()
seq.add_timestep().set_token(0)
seq.add_timestep().set_token(1)
seq.add_timestep().set_token(2)
for i, ts in enumerate(seq):
self.assertEqual(ts.token, i)
def testFillsSequenceExampleCorrectly(self):
seq = data.SequenceWrapper()
seq.add_timestep().set_token(1).set_label(2).set_weight(3.0)
seq.add_timestep().set_token(10).set_label(20).set_weight(30.0)
seq_ex = seq.seq
fl = seq_ex.feature_lists.feature_list
fl_token = fl[data.SequenceWrapper.F_TOKEN_ID].feature
fl_label = fl[data.SequenceWrapper.F_LABEL].feature
fl_weight = fl[data.SequenceWrapper.F_WEIGHT].feature
_ = [self.assertEqual(len(f), 2) for f in [fl_token, fl_label, fl_weight]]
self.assertAllEqual([f.int64_list.value[0] for f in fl_token], [1, 10])
self.assertAllEqual([f.int64_list.value[0] for f in fl_label], [2, 20])
self.assertAllEqual([f.float_list.value[0] for f in fl_weight], [3.0, 30.0])
class DataUtilsTest(tf.test.TestCase):
def testSplitByPunct(self):
output = data.split_by_punct(
'hello! world, i\'ve been\nwaiting\tfor\ryou for.a long time')
expected = [
'hello', 'world', 'i', 've', 'been', 'waiting', 'for', 'you', 'for',
'a', 'long', 'time'
]
self.assertListEqual(output, expected)
def _buildDummySequence(self):
seq = data.SequenceWrapper()
for i in range(10):
seq.add_timestep().set_token(i)
return seq
def testBuildLMSeq(self):
seq = self._buildDummySequence()
lm_seq = data.build_lm_sequence(seq)
for i, ts in enumerate(lm_seq):
self.assertEqual(ts.token, i)
self.assertEqual(ts.label, i + 1)
self.assertEqual(ts.weight, 1.0)
def testBuildSAESeq(self):
seq = self._buildDummySequence()
sa_seq = data.build_seq_ae_sequence(seq)
self.assertEqual(len(sa_seq), len(seq) * 2 - 1)
# Tokens should be sequence twice, minus the EOS token at the end
for i, ts in enumerate(sa_seq):
self.assertEqual(ts.token, seq[i % 10].token)
# Weights should be len-1 0.0's and len 1.0's.
for i in range(len(seq) - 1):
self.assertEqual(sa_seq[i].weight, 0.0)
for i in range(len(seq) - 1, len(sa_seq)):
self.assertEqual(sa_seq[i].weight, 1.0)
# Labels should be len-1 0's, and then the sequence
for i in range(len(seq) - 1):
self.assertEqual(sa_seq[i].label, 0)
for i in range(len(seq) - 1, len(sa_seq)):
self.assertEqual(sa_seq[i].label, seq[i - (len(seq) - 1)].token)
def testBuildLabelSeq(self):
seq = self._buildDummySequence()
eos_id = len(seq) - 1
label_seq = data.build_labeled_sequence(seq, True)
for i, ts in enumerate(label_seq[:-1]):
self.assertEqual(ts.token, i)
self.assertEqual(ts.label, 0)
self.assertEqual(ts.weight, 0.0)
final_timestep = label_seq[-1]
self.assertEqual(final_timestep.token, eos_id)
self.assertEqual(final_timestep.label, 1)
self.assertEqual(final_timestep.weight, 1.0)
def testBuildBidirLabelSeq(self):
seq = self._buildDummySequence()
reverse_seq = data.build_reverse_sequence(seq)
bidir_seq = data.build_bidirectional_seq(seq, reverse_seq)
label_seq = data.build_labeled_sequence(bidir_seq, True)
for (i, ts), j in zip(
enumerate(label_seq[:-1]), reversed(range(len(seq) - 1))):
self.assertAllEqual(ts.tokens, [i, j])
self.assertEqual(ts.label, 0)
self.assertEqual(ts.weight, 0.0)
final_timestep = label_seq[-1]
eos_id = len(seq) - 1
self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id])
self.assertEqual(final_timestep.label, 1)
self.assertEqual(final_timestep.weight, 1.0)
def testReverseSeq(self):
seq = self._buildDummySequence()
reverse_seq = data.build_reverse_sequence(seq)
for i, ts in enumerate(reversed(reverse_seq[:-1])):
self.assertEqual(ts.token, i)
self.assertEqual(ts.label, 0)
self.assertEqual(ts.weight, 0.0)
final_timestep = reverse_seq[-1]
eos_id = len(seq) - 1
self.assertEqual(final_timestep.token, eos_id)
self.assertEqual(final_timestep.label, 0)
self.assertEqual(final_timestep.weight, 0.0)
def testBidirSeq(self):
seq = self._buildDummySequence()
reverse_seq = data.build_reverse_sequence(seq)
bidir_seq = data.build_bidirectional_seq(seq, reverse_seq)
for (i, ts), j in zip(
enumerate(bidir_seq[:-1]), reversed(range(len(seq) - 1))):
self.assertAllEqual(ts.tokens, [i, j])
self.assertEqual(ts.label, 0)
self.assertEqual(ts.weight, 0.0)
final_timestep = bidir_seq[-1]
eos_id = len(seq) - 1
self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id])
self.assertEqual(final_timestep.label, 0)
self.assertEqual(final_timestep.weight, 0.0)
def testLabelGain(self):
seq = self._buildDummySequence()
label_seq = data.build_labeled_sequence(seq, True, label_gain=True)
for i, ts in enumerate(label_seq):
self.assertEqual(ts.token, i)
self.assertEqual(ts.label, 1)
self.assertNear(ts.weight, float(i) / (len(seq) - 1), 1e-3)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input readers and document/token generators for datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import csv
import os
import random
import tensorflow as tf
from adversarial_text.data import data_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('dataset', '', 'Which dataset to generate data for')
# Preprocessing config
flags.DEFINE_boolean('output_unigrams', True, 'Whether to output unigrams.')
flags.DEFINE_boolean('output_bigrams', False, 'Whether to output bigrams.')
flags.DEFINE_boolean('output_char', False, 'Whether to output characters.')
flags.DEFINE_boolean('lowercase', True, 'Whether to lowercase document terms.')
# IMDB
flags.DEFINE_string('imdb_input_dir', '', 'The input directory containing the '
'IMDB sentiment dataset.')
flags.DEFINE_integer('imdb_validation_pos_start_id', 10621, 'File id of the '
'first file in the pos sentiment validation set.')
flags.DEFINE_integer('imdb_validation_neg_start_id', 10625, 'File id of the '
'first file in the neg sentiment validation set.')
# DBpedia
flags.DEFINE_string('dbpedia_input_dir', '',
'Path to DBpedia directory containing train.csv and '
'test.csv.')
# Reuters Corpus (rcv1)
flags.DEFINE_string('rcv1_input_dir', '',
'Path to rcv1 directory containing train.csv, unlab.csv, '
'and test.csv.')
# Rotten Tomatoes
flags.DEFINE_string('rt_input_dir', '',
'The Rotten Tomatoes dataset input directory.')
# The Amazon reviews input file to use with either the RT or IMDB datasets.
flags.DEFINE_string('amazon_unlabeled_input_file', '',
'The unlabeled Amazon Reviews dataset input file. If set, '
'the input file is used to augment RT and IMDB vocab.')
Document = namedtuple('Document',
'content is_validation is_test label add_tokens')
def documents(dataset='train',
include_unlabeled=False,
include_validation=False):
"""Generates Documents based on FLAGS.dataset.
Args:
dataset: str, identifies folder within IMDB data directory, test or train.
include_unlabeled: bool, whether to include the unsup directory. Only valid
when dataset=train.
include_validation: bool, whether to include validation data.
Yields:
Document
Raises:
ValueError: if include_unlabeled is true but dataset is not 'train'
"""
if include_unlabeled and dataset != 'train':
raise ValueError('If include_unlabeled=True, must use train dataset')
# Set the random seed so that we have the same validation set when running
# gen_data and gen_vocab.
random.seed(302)
ds = FLAGS.dataset
if ds == 'imdb':
docs_gen = imdb_documents
elif ds == 'dbpedia':
docs_gen = dbpedia_documents
elif ds == 'rcv1':
docs_gen = rcv1_documents
elif ds == 'rt':
docs_gen = rt_documents
else:
raise ValueError('Unrecognized dataset %s' % FLAGS.dataset)
for doc in docs_gen(dataset, include_unlabeled, include_validation):
yield doc
def tokens(doc):
"""Given a Document, produces character or word tokens.
Tokens can be either characters, or word-level tokens (unigrams and/or
bigrams).
Args:
doc: Document to produce tokens from.
Yields:
token
Raises:
ValueError: if all FLAGS.{output_unigrams, output_bigrams, output_char}
are False.
"""
if not (FLAGS.output_unigrams or FLAGS.output_bigrams or FLAGS.output_char):
raise ValueError(
'At least one of {FLAGS.output_unigrams, FLAGS.output_bigrams, '
'FLAGS.output_char} must be true')
content = doc.content.strip()
if FLAGS.lowercase:
content = content.lower()
if FLAGS.output_char:
for char in content:
yield char
else:
tokens_ = data_utils.split_by_punct(content)
for i, token in enumerate(tokens_):
if FLAGS.output_unigrams:
yield token
if FLAGS.output_bigrams:
previous_token = (tokens_[i - 1] if i > 0 else data_utils.EOS_TOKEN)
bigram = '_'.join([previous_token, token])
yield bigram
if (i + 1) == len(tokens_):
bigram = '_'.join([token, data_utils.EOS_TOKEN])
yield bigram
def imdb_documents(dataset='train',
include_unlabeled=False,
include_validation=False):
"""Generates Documents for IMDB dataset.
Data from http://ai.stanford.edu/~amaas/data/sentiment/
Args:
dataset: str, identifies folder within IMDB data directory, test or train.
include_unlabeled: bool, whether to include the unsup directory. Only valid
when dataset=train.
include_validation: bool, whether to include validation data.
Yields:
Document
Raises:
ValueError: if FLAGS.imdb_input_dir is empty.
"""
if not FLAGS.imdb_input_dir:
raise ValueError('Must provide FLAGS.imdb_input_dir')
tf.logging.info('Generating IMDB documents...')
def check_is_validation(filename, class_label):
if class_label is None:
return False
file_idx = int(filename.split('_')[0])
is_pos_valid = (class_label and
file_idx >= FLAGS.imdb_validation_pos_start_id)
is_neg_valid = (not class_label and
file_idx >= FLAGS.imdb_validation_neg_start_id)
return is_pos_valid or is_neg_valid
dirs = [(dataset + '/pos', True), (dataset + '/neg', False)]
if include_unlabeled:
dirs.append(('train/unsup', None))
for d, class_label in dirs:
for filename in os.listdir(os.path.join(FLAGS.imdb_input_dir, d)):
is_validation = check_is_validation(filename, class_label)
if is_validation and not include_validation:
continue
with open(os.path.join(FLAGS.imdb_input_dir, d, filename)) as imdb_f:
content = imdb_f.read()
yield Document(
content=content,
is_validation=is_validation,
is_test=False,
label=class_label,
add_tokens=True)
if FLAGS.amazon_unlabeled_input_file and include_unlabeled:
with open(FLAGS.amazon_unlabeled_input_file) as rt_f:
for content in rt_f:
yield Document(content=content, is_validation=False, is_test=False,
label=None, add_tokens=False)
def dbpedia_documents(dataset='train',
include_unlabeled=False,
include_validation=False):
"""Generates Documents for DBpedia dataset.
Dataset linked to at https://github.com/zhangxiangxiao/Crepe.
Args:
dataset: str, identifies the csv file within the DBpedia data directory,
test or train.
include_unlabeled: bool, unused.
include_validation: bool, whether to include validation data, which is a
randomly selected 10% of the data.
Yields:
Document
Raises:
ValueError: if FLAGS.dbpedia_input_dir is empty.
"""
del include_unlabeled
if not FLAGS.dbpedia_input_dir:
raise ValueError('Must provide FLAGS.dbpedia_input_dir')
tf.logging.info('Generating DBpedia documents...')
with open(os.path.join(FLAGS.dbpedia_input_dir, dataset + '.csv')) as db_f:
reader = csv.reader(db_f)
for row in reader:
# 10% of the data is randomly held out
is_validation = random.randint(1, 10) == 1
if is_validation and not include_validation:
continue
content = row[1] + ' ' + row[2]
yield Document(
content=content,
is_validation=is_validation,
is_test=False,
label=int(row[0]),
add_tokens=True)
def rcv1_documents(dataset='train',
include_unlabeled=True,
include_validation=False):
# pylint:disable=line-too-long
"""Generates Documents for Reuters Corpus (rcv1) dataset.
Dataset described at http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm
Args:
dataset: str, identifies the csv file within the rcv1 data directory.
include_unlabeled: bool, whether to include the unlab file. Only valid
when dataset=train.
include_validation: bool, whether to include validation data, which is a
randomly selected 10% of the data.
Yields:
Document
Raises:
ValueError: if FLAGS.rcv1_input_dir is empty.
"""
# pylint:enable=line-too-long
if not FLAGS.rcv1_input_dir:
raise ValueError('Must provide FLAGS.rcv1_input_dir')
tf.logging.info('Generating rcv1 documents...')
datasets = [dataset]
if include_unlabeled:
if dataset == 'train':
datasets.append('unlab')
for dset in datasets:
with open(os.path.join(FLAGS.rcv1_input_dir, dset + '.csv')) as db_f:
reader = csv.reader(db_f)
for row in reader:
# 10% of the data is randomly held out
is_validation = random.randint(1, 10) == 1
if is_validation and not include_validation:
continue
content = row[1]
yield Document(
content=content,
is_validation=is_validation,
is_test=False,
label=int(row[0]),
add_tokens=True)
def rt_documents(dataset='train',
include_unlabeled=True,
include_validation=False):
# pylint:disable=line-too-long
"""Generates Documents for the Rotten Tomatoes dataset.
Dataset available at http://www.cs.cornell.edu/people/pabo/movie-review-data/
In this dataset, Amazon reviews are used for the unlabeled data.
Args:
dataset: str, identifies the data subdirectory.
include_unlabeled: bool, whether to include the unlabeled data. Only valid
when dataset=train.
include_validation: bool, whether to include validation data, which is a
randomly selected 10% of the data.
Yields:
Document
Raises:
ValueError: if FLAGS.rt_input_dir is empty.
"""
# pylint:enable=line-too-long
if not FLAGS.rt_input_dir:
raise ValueError('Must provide FLAGS.rt_input_dir')
tf.logging.info('Generating rt documents...')
data_files = []
input_filenames = os.listdir(FLAGS.rt_input_dir)
for inp_fname in input_filenames:
if inp_fname.endswith('.pos'):
data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), True))
elif inp_fname.endswith('.neg'):
data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), False))
if include_unlabeled and FLAGS.amazon_unlabeled_input_file:
data_files.append((FLAGS.amazon_unlabeled_input_file, None))
for filename, class_label in data_files:
with open(filename) as rt_f:
for content in rt_f:
if class_label is None:
# Process Amazon Review data for unlabeled dataset
if content.startswith('review/text'):
yield Document(content=content, is_validation=False,
is_test=False, label=None, add_tokens=False)
else:
# 10% of the data is randomly held out for the validation set and
# another 10% of it is randomly held out for the test set
random_int = random.randint(1, 10)
is_validation = random_int == 1
is_test = random_int == 2
if (is_test and dataset != 'test') or (
is_validation and not include_validation):
continue
yield Document(content=content, is_validation=is_validation,
is_test=is_test, label=class_label, add_tokens=True)
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create TFRecord files of SequenceExample protos from dataset.
Constructs 3 datasets:
1. Labeled data for the LSTM classification model, optionally with label gain.
"*_classification.tfrecords" (for both unidirectional and bidirectional
models).
2. Data for the unsupervised LM-LSTM model that predicts the next token.
"*_lm.tfrecords" (generates forward and reverse data).
3. Data for the unsupervised SA-LSTM model that uses Seq2Seq.
"*_sa.tfrecords".
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import string
import tensorflow as tf
from adversarial_text.data import data_utils
from adversarial_text.data import document_generators
data = data_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
# Flags for input data are in document_generators.py
flags.DEFINE_string('vocab_file', '', 'Path to the vocabulary file. Defaults '
'to FLAGS.output_dir/vocab.txt.')
flags.DEFINE_string('output_dir', '', 'Path to save tfrecords.')
# Config
flags.DEFINE_boolean('label_gain', False,
'Enable linear label gain. If True, sentiment label will '
'be included at each timestep with linear weight '
'increase.')
def build_shuffling_tf_record_writer(fname):
return data.ShufflingTFRecordWriter(os.path.join(FLAGS.output_dir, fname))
def build_tf_record_writer(fname):
return tf.python_io.TFRecordWriter(os.path.join(FLAGS.output_dir, fname))
def build_input_sequence(doc, vocab_ids):
"""Builds input sequence from file.
Splits lines on whitespace. Treats punctuation as whitespace. For word-level
sequences, only keeps terms that are in the vocab.
Terms are added as tokens in the SequenceExample. The EOS_TOKEN is also
appended. Label and weight features are set to 0.
Args:
doc: Document (defined in `document_generators`) from which to build the
sequence.
vocab_ids: dict<term, id>.
Returns:
SequenceWrapper.
"""
seq = data.SequenceWrapper()
for token in document_generators.tokens(doc):
if token in vocab_ids:
seq.add_timestep().set_token(vocab_ids[token])
# Add EOS token to end
seq.add_timestep().set_token(vocab_ids[data.EOS_TOKEN])
return seq
def make_vocab_ids(vocab_filename):
if FLAGS.output_char:
ret = dict([(char, i) for i, char in enumerate(string.printable)])
ret[data.EOS_TOKEN] = len(string.printable)
return ret
else:
with open(vocab_filename) as vocab_f:
return dict([(line.strip(), i) for i, line in enumerate(vocab_f)])
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
"""Generates training data."""
# Construct training data writers
writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM)
writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA)
writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS)
writer_valid_class = build_tf_record_writer(data.VALID_CLASS)
writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM)
writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS)
writer_bd_valid_class = build_shuffling_tf_record_writer(data.VALID_BD_CLASS)
for doc in document_generators.documents(
dataset='train', include_unlabeled=True, include_validation=True):
input_seq = build_input_sequence(doc, vocab_ids)
if len(input_seq) < 2:
continue
rev_seq = data.build_reverse_sequence(input_seq)
lm_seq = data.build_lm_sequence(input_seq)
rev_lm_seq = data.build_lm_sequence(rev_seq)
seq_ae_seq = data.build_seq_ae_sequence(input_seq)
if doc.label is not None:
# Used for sentiment classification.
label_seq = data.build_labeled_sequence(
input_seq,
doc.label,
label_gain=(FLAGS.label_gain and not doc.is_validation))
bd_label_seq = data.build_labeled_sequence(
data.build_bidirectional_seq(input_seq, rev_seq),
doc.label,
label_gain=(FLAGS.label_gain and not doc.is_validation))
class_writer = writer_valid_class if doc.is_validation else writer_class
bd_class_writer = (writer_bd_valid_class
if doc.is_validation else writer_bd_class)
class_writer.write(label_seq.seq.SerializeToString())
bd_class_writer.write(bd_label_seq.seq.SerializeToString())
# Write
lm_seq_ser = lm_seq.seq.SerializeToString()
seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
writer_lm_all.write(lm_seq_ser)
writer_seq_ae_all.write(seq_ae_seq_ser)
if not doc.is_validation:
writer_lm.write(lm_seq_ser)
writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
writer_seq_ae.write(seq_ae_seq_ser)
# Close writers
writer_lm.close()
writer_seq_ae.close()
writer_class.close()
writer_valid_class.close()
writer_rev_lm.close()
writer_bd_class.close()
writer_bd_valid_class.close()
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
"""Generates test data."""
# Construct test data writers
writer_lm = build_shuffling_tf_record_writer(data.TEST_LM)
writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM)
writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA)
writer_class = build_tf_record_writer(data.TEST_CLASS)
writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS)
for doc in document_generators.documents(
dataset='test', include_unlabeled=False, include_validation=True):
input_seq = build_input_sequence(doc, vocab_ids)
if len(input_seq) < 2:
continue
rev_seq = data.build_reverse_sequence(input_seq)
lm_seq = data.build_lm_sequence(input_seq)
rev_lm_seq = data.build_lm_sequence(rev_seq)
seq_ae_seq = data.build_seq_ae_sequence(input_seq)
label_seq = data.build_labeled_sequence(input_seq, doc.label)
bd_label_seq = data.build_labeled_sequence(
data.build_bidirectional_seq(input_seq, rev_seq), doc.label)
# Write
writer_class.write(label_seq.seq.SerializeToString())
writer_bd_class.write(bd_label_seq.seq.SerializeToString())
lm_seq_ser = lm_seq.seq.SerializeToString()
seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
writer_lm.write(lm_seq_ser)
writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
writer_seq_ae.write(seq_ae_seq_ser)
writer_lm_all.write(lm_seq_ser)
writer_seq_ae_all.write(seq_ae_seq_ser)
# Close test writers
writer_lm.close()
writer_rev_lm.close()
writer_seq_ae.close()
writer_class.close()
writer_bd_class.close()
def main(_):
tf.logging.info('Assigning vocabulary ids...')
vocab_ids = make_vocab_ids(
FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt'))
with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all:
with build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all:
tf.logging.info('Generating training data...')
generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all)
tf.logging.info('Generating test data...')
generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates vocabulary and term frequency files for datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import tensorflow as tf
from adversarial_text.data import data_utils
from adversarial_text.data import document_generators
flags = tf.app.flags
FLAGS = flags.FLAGS
# Flags controlling input are in document_generators.py
flags.DEFINE_string('output_dir', '',
'Path to save vocab.txt and vocab_freq.txt.')
flags.DEFINE_boolean('use_unlabeled', True, 'Whether to use the '
'unlabeled sentiment dataset in the vocabulary.')
flags.DEFINE_boolean('include_validation', False, 'Whether to include the '
'validation set in the vocabulary.')
flags.DEFINE_integer('doc_count_threshold', 1, 'The minimum number of '
'documents a word or bigram should occur in to keep '
'it in the vocabulary.')
MAX_VOCAB_SIZE = 100 * 1000
def fill_vocab_from_doc(doc, vocab_freqs, doc_counts):
"""Fills vocabulary and doc counts with tokens from doc.
Args:
doc: Document to read tokens from.
vocab_freqs: dict<token, frequency count>
doc_counts: dict<token, document count>
Returns:
None
"""
doc_seen = set()
for token in document_generators.tokens(doc):
if doc.add_tokens or token in vocab_freqs:
vocab_freqs[token] += 1
if token not in doc_seen:
doc_counts[token] += 1
doc_seen.add(token)
def main(_):
vocab_freqs = defaultdict(int)
doc_counts = defaultdict(int)
# Fill vocabulary frequencies map and document counts map
for doc in document_generators.documents(
dataset='train',
include_unlabeled=FLAGS.use_unlabeled,
include_validation=FLAGS.include_validation):
fill_vocab_from_doc(doc, vocab_freqs, doc_counts)
# Filter out low-occurring terms
vocab_freqs = dict((term, freq) for term, freq in vocab_freqs.items()
if doc_counts[term] > FLAGS.doc_count_threshold)
# Sort by frequency
ordered_vocab_freqs = data_utils.sort_vocab_by_frequency(vocab_freqs)
# Limit vocab size
ordered_vocab_freqs = ordered_vocab_freqs[:MAX_VOCAB_SIZE]
# Add EOS token
ordered_vocab_freqs.append((data_utils.EOS_TOKEN, 1))
# Write
tf.gfile.MakeDirs(FLAGS.output_dir)
data_utils.write_vocab_and_frequency(ordered_vocab_freqs, FLAGS.output_dir)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates text classification model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import time
import tensorflow as tf
import graphs
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '',
'BNS name prefix of the Tensorflow eval master, '
'or "local".')
flags.DEFINE_string('eval_dir', '/tmp/text_eval',
'Directory where to write event logs.')
flags.DEFINE_string('eval_data', 'test', 'Which dataset to evaluate: '
'"train", "valid", or "test".')
flags.DEFINE_string('checkpoint_dir', '/tmp/text_train',
'Directory where to read model checkpoints.')
flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run the eval.')
flags.DEFINE_integer('num_examples', 32, 'Number of examples to run.')
flags.DEFINE_bool('run_once', False, 'Whether to run eval only once.')
def restore_from_checkpoint(sess, saver):
"""Restore model from checkpoint.
Args:
sess: Session.
saver: Saver for restoring the checkpoint.
Returns:
bool: Whether the checkpoint was found and restored
"""
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if not ckpt or not ckpt.model_checkpoint_path:
tf.logging.info('No checkpoint found at %s', FLAGS.checkpoint_dir)
return False
saver.restore(sess, ckpt.model_checkpoint_path)
return True
def run_eval(eval_ops, summary_writer, saver):
"""Runs evaluation over FLAGS.num_examples examples.
Args:
eval_ops: dict<metric name, tuple(value, update_op)>
summary_writer: Summary writer.
saver: Saver.
Returns:
dict<metric name, value>, with value being the average over all examples.
"""
sv = tf.train.Supervisor(logdir=FLAGS.eval_dir, saver=None, summary_op=None)
with sv.managed_session(
master=FLAGS.master, start_standard_services=False) as sess:
if not restore_from_checkpoint(sess, saver):
return
sv.start_queue_runners(sess)
metric_names, ops = zip(*eval_ops.items())
value_ops, update_ops = zip(*ops)
# Run update ops
num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
tf.logging.info('Running %d batches for evaluation.', num_batches)
for i in range(num_batches):
if (i + 1) % 10 == 0:
tf.logging.info('Running batch %d/%d...', i + 1, num_batches)
sess.run(update_ops)
values = sess.run(value_ops)
metric_values = dict(zip(metric_names, values))
tf.logging.info('Eval metric values:')
summary = tf.summary.Summary()
for name, val in metric_values.items():
summary.value.add(tag=name, simple_value=val)
tf.logging.info('%s = %.3f', name, val)
global_step_val = sess.run(tf.train.get_global_step())
summary_writer.add_summary(summary, global_step_val)
return metric_values
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
tf.gfile.MakeDirs(FLAGS.eval_dir)
tf.logging.info('Building eval graph...')
output = graphs.get_model().eval_graph(FLAGS.eval_data)
eval_ops, moving_averaged_variables = output
saver = tf.train.Saver(moving_averaged_variables)
summary_writer = tf.summary.FileWriter(
FLAGS.eval_dir, graph=tf.get_default_graph())
while True:
run_eval(eval_ops, summary_writer, saver)
if FLAGS.run_once:
break
time.sleep(FLAGS.eval_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Virtual adversarial text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import csv
import os
import tensorflow as tf
import adversarial_losses as adv_lib
import inputs as inputs_lib
import layers as layers_lib
flags = tf.app.flags
FLAGS = flags.FLAGS
# Flags governing adversarial training are defined in adversarial_losses.py.
# Classifier
flags.DEFINE_integer('num_classes', 2, 'Number of classes for classification')
# Data path
flags.DEFINE_string('data_dir', '/tmp/IMDB',
'Directory path to preprocessed text dataset.')
flags.DEFINE_string('vocab_freq_path', None,
'Path to pre-calculated vocab frequency data. If '
'None, use FLAGS.data_dir/vocab_freq.txt.')
flags.DEFINE_integer('batch_size', 64, 'Size of the batch.')
flags.DEFINE_integer('num_timesteps', 100, 'Number of timesteps for BPTT')
# Model architecture
flags.DEFINE_bool('bidir_lstm', False, 'Whether to build a bidirectional LSTM.')
flags.DEFINE_integer('rnn_num_layers', 1, 'Number of LSTM layers.')
flags.DEFINE_integer('rnn_cell_size', 512,
'Number of hidden units in the LSTM.')
flags.DEFINE_integer('cl_num_layers', 1,
'Number of hidden layers of classification model.')
flags.DEFINE_integer('cl_hidden_size', 30,
'Number of hidden units in classification layer.')
flags.DEFINE_integer('num_candidate_samples', -1,
'Num samples used in the sampled output layer.')
flags.DEFINE_bool('use_seq2seq_autoencoder', False,
'If True, seq2seq auto-encoder is used to pretrain. '
'If False, standard language model is used.')
# Vocabulary and embeddings
flags.DEFINE_integer('embedding_dims', 256, 'Dimensions of embedded vector.')
flags.DEFINE_integer('vocab_size', 86934,
'The size of the vocabulary. This value '
'should exactly match the number of terms in '
'the vocabulary generated for the dataset, '
'because the last vocabulary index is always '
'reserved for <eos>, and the <eos> id is '
'derived from this value (vocab_size - 1).')
flags.DEFINE_bool('normalize_embeddings', True,
'Normalize word embeddings by vocab frequency')
# Optimization
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate while fine-tuning.')
flags.DEFINE_float('learning_rate_decay_factor', 1.0,
'Learning rate decay factor')
flags.DEFINE_boolean('sync_replicas', False,
'Whether to use synchronized replica training.')
flags.DEFINE_integer('replicas_to_aggregate', 1,
'The number of replicas to aggregate')
# Regularization
flags.DEFINE_float('max_grad_norm', 1.0,
'Clip the global gradient norm to this value.')
flags.DEFINE_float('keep_prob_emb', 1.0, 'keep probability on embedding layer')
flags.DEFINE_float('keep_prob_lstm_out', 1.0,
'keep probability on lstm output.')
flags.DEFINE_float('keep_prob_cl_hidden', 1.0,
'keep probability on classification hidden layer')
def get_model():
if FLAGS.bidir_lstm:
return VatxtBidirModel()
else:
return VatxtModel()
class VatxtModel(object):
"""Constructs training and evaluation graphs.
Main methods: `classifier_training()`, `language_model_training()`,
and `eval_graph()`.
Variable reuse is a critical part of the model, both for sharing variables
between the language model and the classifier, and for reusing variables for
the adversarial loss calculation. To ensure correct variable reuse, all
variables are created in Keras-style layers, wherein stateful layers (i.e.
layers with variables) are represented as callable instances of the Layer
class. Each call to a Layer instance reuses the same underlying variables.
All Layers are constructed in the __init__ method and reused in the various
graph-building functions.
"""
def __init__(self, cl_logits_input_dim=None):
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.vocab_freqs = _get_vocab_freqs()
# Cache VatxtInput objects
self.cl_inputs = None
self.lm_inputs = None
# Cache intermediate Tensors that are reused
self.tensors = {}
# Construct layers which are reused in constructing the LM and
# Classification graphs. Instantiating them all once here ensures that
# variable reuse works correctly.
self.layers = {}
self.layers['embedding'] = layers_lib.Embedding(
FLAGS.vocab_size, FLAGS.embedding_dims, FLAGS.normalize_embeddings,
self.vocab_freqs, FLAGS.keep_prob_emb)
self.layers['lstm'] = layers_lib.LSTM(
FLAGS.rnn_cell_size, FLAGS.rnn_num_layers, FLAGS.keep_prob_lstm_out)
self.layers['lm_loss'] = layers_lib.SoftmaxLoss(
FLAGS.vocab_size,
FLAGS.num_candidate_samples,
self.vocab_freqs,
name='LM_loss')
cl_logits_input_dim = cl_logits_input_dim or FLAGS.rnn_cell_size
self.layers['cl_logits'] = layers_lib.cl_logits_subgraph(
[FLAGS.cl_hidden_size] * FLAGS.cl_num_layers, cl_logits_input_dim,
FLAGS.num_classes, FLAGS.keep_prob_cl_hidden)
@property
def pretrained_variables(self):
return (self.layers['embedding'].trainable_weights +
self.layers['lstm'].trainable_weights)
def classifier_training(self):
loss = self.classifier_graph()
train_op = optimize(loss, self.global_step)
return train_op, loss, self.global_step
def language_model_training(self):
loss = self.language_model_graph()
train_op = optimize(loss, self.global_step)
return train_op, loss, self.global_step
def classifier_graph(self):
"""Constructs classifier graph from inputs to classifier loss.
* Caches the VatxtInput object in `self.cl_inputs`
* Caches tensors: `cl_embedded`, `cl_logits`, `cl_loss`
Returns:
loss: scalar float.
"""
inputs = _inputs('train', pretrain=False)
self.cl_inputs = inputs
embedded = self.layers['embedding'](inputs.tokens)
self.tensors['cl_embedded'] = embedded
_, next_state, logits, loss = self.cl_loss_from_embedding(
embedded, return_intermediates=True)
tf.summary.scalar('classification_loss', loss)
self.tensors['cl_logits'] = logits
self.tensors['cl_loss'] = loss
acc = layers_lib.accuracy(logits, inputs.labels, inputs.weights)
tf.summary.scalar('accuracy', acc)
adv_loss = (self.adversarial_loss() * tf.constant(
FLAGS.adv_reg_coeff, name='adv_reg_coeff'))
tf.summary.scalar('adversarial_loss', adv_loss)
total_loss = loss + adv_loss
tf.summary.scalar('total_classification_loss', total_loss)
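    # Tie the returned loss to the state-save op so that evaluating the loss
    # also persists the LSTM state for the next truncated-BPTT segment.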
with tf.control_dependencies([inputs.save_state(next_state)]):
total_loss = tf.identity(total_loss)
return total_loss
def language_model_graph(self, compute_loss=True):
"""Constructs LM graph from inputs to LM loss.
* Caches the VatxtInput object in `self.lm_inputs`
* Caches tensors: `lm_embedded`
Args:
compute_loss: bool, whether to compute and return the loss or stop after
the LSTM computation.
Returns:
loss: scalar float.
"""
inputs = _inputs('train', pretrain=True)
self.lm_inputs = inputs
return self._lm_loss(inputs, compute_loss=compute_loss)
def _lm_loss(self,
inputs,
emb_key='lm_embedded',
lstm_layer='lstm',
lm_loss_layer='lm_loss',
loss_name='lm_loss',
compute_loss=True):
embedded = self.layers['embedding'](inputs.tokens)
self.tensors[emb_key] = embedded
lstm_out, next_state = self.layers[lstm_layer](embedded, inputs.state,
inputs.length)
if compute_loss:
loss = self.layers[lm_loss_layer](
[lstm_out, inputs.labels, inputs.weights])
with tf.control_dependencies([inputs.save_state(next_state)]):
loss = tf.identity(loss)
tf.summary.scalar(loss_name, loss)
return loss
def eval_graph(self, dataset='test'):
"""Constructs classifier evaluation graph.
Args:
dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}.
Returns:
eval_ops: dict<metric name, tuple(value, update_op)>
var_restore_dict: dict mapping variable restoration names to variables.
Trainable variables will be mapped to their moving average names.
"""
inputs = _inputs(dataset, pretrain=False)
embedded = self.layers['embedding'](inputs.tokens)
_, next_state, logits, _ = self.cl_loss_from_embedding(
embedded, inputs=inputs, return_intermediates=True)
eval_ops = {
'accuracy':
tf.contrib.metrics.streaming_accuracy(
layers_lib.predictions(logits), inputs.labels,
inputs.weights)
}
with tf.control_dependencies([inputs.save_state(next_state)]):
acc, acc_update = eval_ops['accuracy']
acc_update = tf.identity(acc_update)
eval_ops['accuracy'] = (acc, acc_update)
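    # Evaluation restores the exponential-moving-average shadow values of the
    # trainable variables, so map their EMA names to the variables here.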
var_restore_dict = make_restore_average_vars_dict()
return eval_ops, var_restore_dict
def cl_loss_from_embedding(self,
embedded,
inputs=None,
return_intermediates=False):
"""Compute classification loss from embedding.
Args:
embedded: 3-D float Tensor [batch_size, num_timesteps, embedding_dim]
inputs: VatxtInput, defaults to self.cl_inputs.
return_intermediates: bool, whether to return intermediate tensors or only
the final loss.
Returns:
If return_intermediates is True:
lstm_out, next_state, logits, loss
Else:
loss
"""
if inputs is None:
inputs = self.cl_inputs
lstm_out, next_state = self.layers['lstm'](embedded, inputs.state,
inputs.length)
logits = self.layers['cl_logits'](lstm_out)
loss = layers_lib.classification_loss(logits, inputs.labels, inputs.weights)
if return_intermediates:
return lstm_out, next_state, logits, loss
else:
return loss
def adversarial_loss(self):
"""Compute adversarial loss based on FLAGS.adv_training_method."""
def random_perturbation_loss():
return adv_lib.random_perturbation_loss(self.tensors['cl_embedded'],
self.cl_inputs.length,
self.cl_loss_from_embedding)
def adversarial_loss():
return adv_lib.adversarial_loss(self.tensors['cl_embedded'],
self.tensors['cl_loss'],
self.cl_loss_from_embedding)
def virtual_adversarial_loss():
"""Computes virtual adversarial loss.
Uses lm_inputs and constructs the language model graph if it hasn't yet
been constructed.
Also ensures that the LM input states are saved for LSTM state-saving
BPTT.
Returns:
loss: float scalar.
"""
if self.lm_inputs is None:
self.language_model_graph(compute_loss=False)
def logits_from_embedding(embedded, return_next_state=False):
_, next_state, logits, _ = self.cl_loss_from_embedding(
embedded, inputs=self.lm_inputs, return_intermediates=True)
if return_next_state:
return next_state, logits
else:
return logits
next_state, lm_cl_logits = logits_from_embedding(
self.tensors['lm_embedded'], return_next_state=True)
va_loss = adv_lib.virtual_adversarial_loss(
lm_cl_logits, self.tensors['lm_embedded'], self.lm_inputs,
logits_from_embedding)
with tf.control_dependencies([self.lm_inputs.save_state(next_state)]):
va_loss = tf.identity(va_loss)
return va_loss
def combo_loss():
return adversarial_loss() + virtual_adversarial_loss()
adv_training_methods = {
# Random perturbation
'rp': random_perturbation_loss,
# Adversarial training
'at': adversarial_loss,
# Virtual adversarial training
'vat': virtual_adversarial_loss,
# Both at and vat
'atvat': combo_loss,
'': lambda: tf.constant(0.),
None: lambda: tf.constant(0.),
}
with tf.name_scope('adversarial_loss'):
return adv_training_methods[FLAGS.adv_training_method]()
class VatxtBidirModel(VatxtModel):
"""Extension of VatxtModel that supports bidirectional input."""
def __init__(self):
super(VatxtBidirModel,
self).__init__(cl_logits_input_dim=FLAGS.rnn_cell_size * 2)
# Reverse LSTM and LM loss for bidirectional models
self.layers['lstm_reverse'] = layers_lib.LSTM(
FLAGS.rnn_cell_size,
FLAGS.rnn_num_layers,
FLAGS.keep_prob_lstm_out,
name='LSTM_Reverse')
self.layers['lm_loss_reverse'] = layers_lib.SoftmaxLoss(
FLAGS.vocab_size,
FLAGS.num_candidate_samples,
self.vocab_freqs,
name='LM_loss_reverse')
@property
def pretrained_variables(self):
variables = super(VatxtBidirModel, self).pretrained_variables
variables.extend(self.layers['lstm_reverse'].trainable_weights)
return variables
def classifier_graph(self):
"""Constructs classifier graph from inputs to classifier loss.
* Caches the VatxtInput objects in `self.cl_inputs`
* Caches tensors: `cl_embedded` (tuple of forward and reverse), `cl_logits`,
`cl_loss`
Returns:
loss: scalar float.
"""
inputs = _inputs('train', pretrain=False, bidir=True)
self.cl_inputs = inputs
f_inputs, _ = inputs
# Embed both forward and reverse with a shared embedding
embedded = [self.layers['embedding'](inp.tokens) for inp in inputs]
self.tensors['cl_embedded'] = embedded
_, next_states, logits, loss = self.cl_loss_from_embedding(
embedded, return_intermediates=True)
tf.summary.scalar('classification_loss', loss)
self.tensors['cl_logits'] = logits
self.tensors['cl_loss'] = loss
acc = layers_lib.accuracy(logits, f_inputs.labels, f_inputs.weights)
tf.summary.scalar('accuracy', acc)
adv_loss = (self.adversarial_loss() * tf.constant(
FLAGS.adv_reg_coeff, name='adv_reg_coeff'))
tf.summary.scalar('adversarial_loss', adv_loss)
total_loss = loss + adv_loss
tf.summary.scalar('total_classification_loss', total_loss)
saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)]
with tf.control_dependencies(saves):
total_loss = tf.identity(total_loss)
return total_loss
def language_model_graph(self, compute_loss=True):
"""Constructs forward and reverse LM graphs from inputs to LM losses.
* Caches the VatxtInput objects in `self.lm_inputs`
* Caches tensors: `lm_embedded`, `lm_embedded_reverse`
Args:
compute_loss: bool, whether to compute and return the loss or stop after
the LSTM computation.
Returns:
loss: scalar float, sum of forward and reverse losses.
"""
inputs = _inputs('train', pretrain=True, bidir=True)
self.lm_inputs = inputs
f_inputs, r_inputs = inputs
f_loss = self._lm_loss(f_inputs, compute_loss=compute_loss)
r_loss = self._lm_loss(
r_inputs,
emb_key='lm_embedded_reverse',
lstm_layer='lstm_reverse',
lm_loss_layer='lm_loss_reverse',
loss_name='lm_loss_reverse',
compute_loss=compute_loss)
if compute_loss:
return f_loss + r_loss
def eval_graph(self, dataset='test'):
"""Constructs classifier evaluation graph.
Args:
dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}.
Returns:
eval_ops: dict<metric name, tuple(value, update_op)>
var_restore_dict: dict mapping variable restoration names to variables.
Trainable variables will be mapped to their moving average names.
"""
inputs = _inputs(dataset, pretrain=False, bidir=True)
embedded = [self.layers['embedding'](inp.tokens) for inp in inputs]
_, next_states, logits, _ = self.cl_loss_from_embedding(
embedded, inputs=inputs, return_intermediates=True)
f_inputs, _ = inputs
eval_ops = {
'accuracy':
tf.contrib.metrics.streaming_accuracy(
layers_lib.predictions(logits), f_inputs.labels,
f_inputs.weights)
}
# Save states on accuracy update
saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)]
with tf.control_dependencies(saves):
acc, acc_update = eval_ops['accuracy']
acc_update = tf.identity(acc_update)
eval_ops['accuracy'] = (acc, acc_update)
var_restore_dict = make_restore_average_vars_dict()
return eval_ops, var_restore_dict
def cl_loss_from_embedding(self,
embedded,
inputs=None,
return_intermediates=False):
"""Compute classification loss from embedding.
Args:
embedded: Length 2 tuple of 3-D float Tensor
[batch_size, num_timesteps, embedding_dim].
inputs: Length 2 tuple of VatxtInput, defaults to self.cl_inputs.
return_intermediates: bool, whether to return intermediate tensors or only
the final loss.
Returns:
If return_intermediates is True:
lstm_out, next_states, logits, loss
Else:
loss
"""
if inputs is None:
inputs = self.cl_inputs
out = []
for (layer_name, emb, inp) in zip(['lstm', 'lstm_reverse'], embedded,
inputs):
out.append(self.layers[layer_name](emb, inp.state, inp.length))
lstm_outs, next_states = zip(*out)
# Concatenate output of forward and reverse LSTMs
lstm_out = tf.concat(lstm_outs, 1)
logits = self.layers['cl_logits'](lstm_out)
f_inputs, _ = inputs # pylint: disable=unpacking-non-sequence
loss = layers_lib.classification_loss(logits, f_inputs.labels,
f_inputs.weights)
if return_intermediates:
return lstm_out, next_states, logits, loss
else:
return loss
def adversarial_loss(self):
"""Compute adversarial loss based on FLAGS.adv_training_method."""
def random_perturbation_loss():
return adv_lib.random_perturbation_loss_bidir(self.tensors['cl_embedded'],
self.cl_inputs[0].length,
self.cl_loss_from_embedding)
def adversarial_loss():
return adv_lib.adversarial_loss_bidir(self.tensors['cl_embedded'],
self.tensors['cl_loss'],
self.cl_loss_from_embedding)
def virtual_adversarial_loss():
"""Computes virtual adversarial loss.
Uses lm_inputs and constructs the language model graph if it hasn't yet
been constructed.
Also ensures that the LM input states are saved for LSTM state-saving
BPTT.
Returns:
loss: float scalar.
"""
if self.lm_inputs is None:
self.language_model_graph(compute_loss=False)
def logits_from_embedding(embedded, return_next_state=False):
_, next_states, logits, _ = self.cl_loss_from_embedding(
embedded, inputs=self.lm_inputs, return_intermediates=True)
if return_next_state:
return next_states, logits
else:
return logits
lm_embedded = (self.tensors['lm_embedded'],
self.tensors['lm_embedded_reverse'])
next_states, lm_cl_logits = logits_from_embedding(
lm_embedded, return_next_state=True)
va_loss = adv_lib.virtual_adversarial_loss_bidir(
lm_cl_logits, lm_embedded, self.lm_inputs, logits_from_embedding)
saves = [
inp.save_state(state)
for (inp, state) in zip(self.lm_inputs, next_states)
]
with tf.control_dependencies(saves):
va_loss = tf.identity(va_loss)
return va_loss
def combo_loss():
return adversarial_loss() + virtual_adversarial_loss()
adv_training_methods = {
# Random perturbation
'rp': random_perturbation_loss,
# Adversarial training
'at': adversarial_loss,
# Virtual adversarial training
'vat': virtual_adversarial_loss,
# Both at and vat
'atvat': combo_loss,
'': lambda: tf.constant(0.),
None: lambda: tf.constant(0.),
}
with tf.name_scope('adversarial_loss'):
return adv_training_methods[FLAGS.adv_training_method]()
def _inputs(dataset='train', pretrain=False, bidir=False):
return inputs_lib.inputs(
data_dir=FLAGS.data_dir,
phase=dataset,
bidir=bidir,
pretrain=pretrain,
use_seq2seq=pretrain and FLAGS.use_seq2seq_autoencoder,
state_size=FLAGS.rnn_cell_size,
num_layers=FLAGS.rnn_num_layers,
batch_size=FLAGS.batch_size,
unroll_steps=FLAGS.num_timesteps)
def _get_vocab_freqs():
"""Returns vocab frequencies.
Returns:
List of integers, length=FLAGS.vocab_size.
Raises:
ValueError: if the length of the frequency file is not equal to the vocab
size, or if the file is not found.
"""
path = FLAGS.vocab_freq_path or os.path.join(FLAGS.data_dir, 'vocab_freq.txt')
if tf.gfile.Exists(path):
with tf.gfile.Open(path) as f:
# Get pre-calculated frequencies of words.
reader = csv.reader(f, quoting=csv.QUOTE_NONE)
freqs = [int(row[-1]) for row in reader]
if len(freqs) != FLAGS.vocab_size:
raise ValueError('Frequency file length %d != vocab size %d' %
(len(freqs), FLAGS.vocab_size))
else:
if FLAGS.vocab_freq_path:
raise ValueError('vocab_freq_path not found')
freqs = [1] * FLAGS.vocab_size
return freqs
def make_restore_average_vars_dict():
"""Returns dict mapping moving average names to variables."""
var_restore_dict = {}
variable_averages = tf.train.ExponentialMovingAverage(0.999)
for v in tf.global_variables():
if v in tf.trainable_variables():
name = variable_averages.average_name(v)
else:
name = v.op.name
var_restore_dict[name] = v
return var_restore_dict
def optimize(loss, global_step):
return layers_lib.optimize(
loss, global_step, FLAGS.max_grad_norm, FLAGS.learning_rate,
FLAGS.learning_rate_decay_factor, FLAGS.sync_replicas,
FLAGS.replicas_to_aggregate, FLAGS.task)
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import operator
import os
import random
import shutil
import string
import tempfile
import tensorflow as tf
import graphs
from adversarial_text.data import data_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
data = data_utils
flags.DEFINE_integer('task', 0, 'Task id; needed for SyncReplicas test')
def _build_random_vocabulary(vocab_size=100):
"""Builds and returns a dict<term, id>."""
vocab = set()
while len(vocab) < (vocab_size - 1):
rand_word = ''.join(
random.choice(string.ascii_lowercase)
for _ in range(random.randint(1, 10)))
vocab.add(rand_word)
vocab_ids = dict([(word, i) for i, word in enumerate(vocab)])
vocab_ids[data.EOS_TOKEN] = vocab_size - 1
return vocab_ids
def _build_random_sequence(vocab_ids):
seq_len = random.randint(10, 200)
ids = vocab_ids.values()
seq = data.SequenceWrapper()
for token_id in [random.choice(ids) for _ in range(seq_len)]:
seq.add_timestep().set_token(token_id)
return seq
def _build_vocab_frequencies(seqs, vocab_ids):
vocab_freqs = defaultdict(int)
ids_to_words = dict([(i, word) for word, i in vocab_ids.iteritems()])
for seq in seqs:
for timestep in seq:
vocab_freqs[ids_to_words[timestep.token]] += 1
vocab_freqs[data.EOS_TOKEN] = 0
return vocab_freqs
class GraphsTest(tf.test.TestCase):
"""Test graph construction methods."""
@classmethod
def setUpClass(cls):
# Make model small
FLAGS.batch_size = 2
FLAGS.num_timesteps = 3
FLAGS.embedding_dims = 4
FLAGS.rnn_num_layers = 2
FLAGS.rnn_cell_size = 4
FLAGS.cl_num_layers = 2
FLAGS.cl_hidden_size = 4
FLAGS.vocab_size = 10
# Set input/output flags
FLAGS.data_dir = tempfile.mkdtemp()
# Build and write sequence files.
vocab_ids = _build_random_vocabulary(FLAGS.vocab_size)
seqs = [_build_random_sequence(vocab_ids) for _ in range(5)]
seqs_label = [
data.build_labeled_sequence(seq, random.choice([True, False]))
for seq in seqs
]
seqs_lm = [data.build_lm_sequence(seq) for seq in seqs]
seqs_ae = [data.build_seq_ae_sequence(seq) for seq in seqs]
seqs_rev = [data.build_reverse_sequence(seq) for seq in seqs]
seqs_bidir = [
data.build_bidirectional_seq(seq, rev)
for seq, rev in zip(seqs, seqs_rev)
]
seqs_bidir_label = [
data.build_labeled_sequence(bd_seq, random.choice([True, False]))
for bd_seq in seqs_bidir
]
filenames = [
data.TRAIN_CLASS, data.TRAIN_LM, data.TRAIN_SA, data.TEST_CLASS,
data.TRAIN_REV_LM, data.TRAIN_BD_CLASS, data.TEST_BD_CLASS
]
seq_lists = [
seqs_label, seqs_lm, seqs_ae, seqs_label, seqs_rev, seqs_bidir,
seqs_bidir_label
]
for fname, seq_list in zip(filenames, seq_lists):
with tf.python_io.TFRecordWriter(
os.path.join(FLAGS.data_dir, fname)) as writer:
for seq in seq_list:
writer.write(seq.seq.SerializeToString())
# Write vocab.txt and vocab_freq.txt
vocab_freqs = _build_vocab_frequencies(seqs, vocab_ids)
ordered_vocab_freqs = sorted(
vocab_freqs.items(), key=operator.itemgetter(1), reverse=True)
with open(os.path.join(FLAGS.data_dir, 'vocab.txt'), 'w') as vocab_f:
with open(os.path.join(FLAGS.data_dir, 'vocab_freq.txt'), 'w') as freq_f:
for word, freq in ordered_vocab_freqs:
vocab_f.write('{}\n'.format(word))
freq_f.write('{}\n'.format(freq))
@classmethod
def tearDownClass(cls):
shutil.rmtree(FLAGS.data_dir)
def setUp(self):
# Reset FLAGS
FLAGS.rnn_num_layers = 1
FLAGS.sync_replicas = False
FLAGS.adv_training_method = None
FLAGS.num_candidate_samples = -1
FLAGS.num_classes = 2
FLAGS.use_seq2seq_autoencoder = False
# Reset Graph
tf.reset_default_graph()
def testClassifierGraph(self):
FLAGS.rnn_num_layers = 2
model = graphs.VatxtModel()
train_op, _, _ = model.classifier_training()
# Pretrained vars: embedding + LSTM layers
self.assertEqual(
len(model.pretrained_variables), 1 + 2 * FLAGS.rnn_num_layers)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess)
sess.run(train_op)
def testLanguageModelGraph(self):
train_op, _, _ = graphs.VatxtModel().language_model_training()
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess)
sess.run(train_op)
def testMulticlass(self):
FLAGS.num_classes = 10
graphs.VatxtModel().classifier_graph()
def testATMethods(self):
at_methods = [None, 'rp', 'at', 'vat', 'atvat']
for method in at_methods:
FLAGS.adv_training_method = method
with tf.Graph().as_default():
graphs.VatxtModel().classifier_graph()
# Ensure variables have been reused
# Embedding + LSTM layers + hidden layers + logits layer
expected_num_vars = 1 + 2 * FLAGS.rnn_num_layers + 2 * (
FLAGS.cl_num_layers) + 2
self.assertEqual(len(tf.trainable_variables()), expected_num_vars)
def testSyncReplicas(self):
FLAGS.sync_replicas = True
graphs.VatxtModel().language_model_training()
def testCandidateSampling(self):
FLAGS.num_candidate_samples = 10
graphs.VatxtModel().language_model_training()
def testSeqAE(self):
FLAGS.use_seq2seq_autoencoder = True
graphs.VatxtModel().language_model_training()
def testBidirLM(self):
graphs.VatxtBidirModel().language_model_graph()
def testBidirClassifier(self):
at_methods = [None, 'rp', 'at', 'vat', 'atvat']
for method in at_methods:
FLAGS.adv_training_method = method
with tf.Graph().as_default():
graphs.VatxtBidirModel().classifier_graph()
# Ensure variables have been reused
# Embedding + 2 LSTM layers + hidden layers + logits layer
expected_num_vars = 1 + 2 * 2 * FLAGS.rnn_num_layers + 2 * (
FLAGS.cl_num_layers) + 2
self.assertEqual(len(tf.trainable_variables()), expected_num_vars)
def testEvalGraph(self):
_, _ = graphs.VatxtModel().eval_graph()
def testBidirEvalGraph(self):
_, _ = graphs.VatxtBidirModel().eval_graph()
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input utils for virtual adversarial text classification."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from adversarial_text.data import data_utils
class VatxtInput(object):
"""Wrapper around NextQueuedSequenceBatch."""
def __init__(self, batch, state_name=None, tokens=None, num_states=0):
"""Construct VatxtInput.
Args:
batch: NextQueuedSequenceBatch.
state_name: str, name of state to fetch and save.
tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence.
      num_states: int, the number of states to store.
"""
self._batch = batch
self._state_name = state_name
self._tokens = (tokens if tokens is not None else
batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
self._num_states = num_states
# Once the tokens have passed through embedding and LSTM, the output Tensor
# shapes will be time-major, i.e. shape = (time, batch, dim). Here we make
# both weights and labels time-major with a transpose, and then merge the
# time and batch dimensions such that they are both vectors of shape
# (time*batch).
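    # For example (hypothetical sizes): with batch=2 and time=3, w starts as
    # shape (2, 3), becomes (3, 2) after the transpose, and (6,) after the
    # reshape, matching the flattened time-major LSTM outputs.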
w = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
w = tf.transpose(w, [1, 0])
w = tf.reshape(w, [-1])
self._weights = w
l = batch.sequences[data_utils.SequenceWrapper.F_LABEL]
l = tf.transpose(l, [1, 0])
l = tf.reshape(l, [-1])
self._labels = l
@property
def tokens(self):
return self._tokens
@property
def weights(self):
return self._weights
@property
def labels(self):
return self._labels
@property
def length(self):
return self._batch.length
@property
def state_name(self):
return self._state_name
@property
def state(self):
# LSTM tuple states
state_names = _get_tuple_state_names(self._num_states, self._state_name)
return tuple([
tf.contrib.rnn.LSTMStateTuple(
self._batch.state(c_name), self._batch.state(h_name))
for c_name, h_name in state_names
])
def save_state(self, value):
# LSTM tuple states
state_names = _get_tuple_state_names(self._num_states, self._state_name)
save_ops = []
for (c_state, h_state), (c_name, h_name) in zip(value, state_names):
save_ops.append(self._batch.save_state(c_name, c_state))
save_ops.append(self._batch.save_state(h_name, h_state))
return tf.group(*save_ops)
def _get_tuple_state_names(num_states, base_name):
"""Returns state names for use with LSTM tuple state."""
state_names = [('{}_{}_c'.format(i, base_name), '{}_{}_h'.format(
i, base_name)) for i in range(num_states)]
return state_names
def _split_bidir_tokens(batch):
tokens = batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID]
# Tokens have shape [batch, time, 2]
# forward and reverse have shape [batch, time].
forward, reverse = [
tf.squeeze(t, axis=[2]) for t in tf.split(tokens, 2, axis=2)
]
return forward, reverse
def _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq):
"""Returns input filenames for configuration.
Args:
phase: str, 'train', 'test', or 'valid'.
bidir: bool, bidirectional model.
pretrain: bool, pretraining or classification.
use_seq2seq: bool, seq2seq data, only valid if pretrain=True.
Returns:
Tuple of filenames.
Raises:
ValueError: if an invalid combination of arguments is provided that does not
map to any data files (e.g. pretrain=False, use_seq2seq=True).
"""
data_spec = (phase, bidir, pretrain, use_seq2seq)
data_specs = {
('train', True, True, False): (data_utils.TRAIN_LM,
data_utils.TRAIN_REV_LM),
('train', True, False, False): (data_utils.TRAIN_BD_CLASS,),
('train', False, True, False): (data_utils.TRAIN_LM,),
('train', False, True, True): (data_utils.TRAIN_SA,),
('train', False, False, False): (data_utils.TRAIN_CLASS,),
('test', True, True, False): (data_utils.TEST_LM,
data_utils.TRAIN_REV_LM),
('test', True, False, False): (data_utils.TEST_BD_CLASS,),
('test', False, True, False): (data_utils.TEST_LM,),
('test', False, True, True): (data_utils.TEST_SA,),
('test', False, False, False): (data_utils.TEST_CLASS,),
('valid', True, False, False): (data_utils.VALID_BD_CLASS,),
('valid', False, False, False): (data_utils.VALID_CLASS,),
}
if data_spec not in data_specs:
raise ValueError(
'Data specification (phase, bidir, pretrain, use_seq2seq) %s not '
'supported' % str(data_spec))
return data_specs[data_spec]
def _read_single_sequence_example(file_list, tokens_shape=None):
"""Reads and parses SequenceExamples from TFRecord-encoded file_list."""
tf.logging.info('Constructing TFRecordReader from files: %s', file_list)
file_queue = tf.train.string_input_producer(file_list)
reader = tf.TFRecordReader()
seq_key, serialized_record = reader.read(file_queue)
ctx, sequence = tf.parse_single_sequence_example(
serialized_record,
sequence_features={
data_utils.SequenceWrapper.F_TOKEN_ID:
tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64),
data_utils.SequenceWrapper.F_LABEL:
tf.FixedLenSequenceFeature([], dtype=tf.int64),
data_utils.SequenceWrapper.F_WEIGHT:
tf.FixedLenSequenceFeature([], dtype=tf.float32),
})
return seq_key, ctx, sequence
def _read_and_batch(data_dir,
fname,
state_name,
state_size,
num_layers,
unroll_steps,
batch_size,
bidir_input=False):
"""Inputs for text model.
Args:
data_dir: str, directory containing TFRecord files of SequenceExample.
fname: str, input file name.
state_name: string, key for saved state of LSTM.
state_size: int, size of LSTM state.
num_layers: int, the number of layers in the LSTM.
    unroll_steps: int, number of timesteps to unroll for truncated BPTT.
batch_size: int, batch size.
bidir_input: bool, whether the input is bidirectional. If True, creates 2
states, state_name and state_name + '_reverse'.
Returns:
Instance of NextQueuedSequenceBatch
Raises:
ValueError: if file for input specification is not found.
"""
data_path = os.path.join(data_dir, fname)
if not tf.gfile.Exists(data_path):
raise ValueError('Failed to find file: %s' % data_path)
tokens_shape = [2] if bidir_input else []
seq_key, ctx, sequence = _read_single_sequence_example(
[data_path], tokens_shape=tokens_shape)
# Set up stateful queue reader.
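  # batch_sequences_with_states (used below) splits each SequenceExample into
  # num_unroll-length segments and carries the saved LSTM state between
  # consecutive segments of the same sequence, which is what enables truncated
  # BPTT with state saving.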
state_names = _get_tuple_state_names(num_layers, state_name)
initial_states = {}
for c_state, h_state in state_names:
initial_states[c_state] = tf.zeros(state_size)
initial_states[h_state] = tf.zeros(state_size)
if bidir_input:
rev_state_names = _get_tuple_state_names(num_layers,
'{}_reverse'.format(state_name))
for rev_c_state, rev_h_state in rev_state_names:
initial_states[rev_c_state] = tf.zeros(state_size)
initial_states[rev_h_state] = tf.zeros(state_size)
batch = tf.contrib.training.batch_sequences_with_states(
input_key=seq_key,
input_sequences=sequence,
input_context=ctx,
input_length=tf.shape(sequence['token_id'])[0],
initial_states=initial_states,
num_unroll=unroll_steps,
batch_size=batch_size,
allow_small_batch=False,
num_threads=4,
capacity=batch_size * 10,
make_keys_unique=True,
make_keys_unique_seed=29392)
return batch
def inputs(data_dir=None,
phase='train',
bidir=False,
pretrain=False,
use_seq2seq=False,
state_name='lstm',
state_size=None,
num_layers=0,
batch_size=32,
unroll_steps=100):
"""Inputs for text model.
Args:
data_dir: str, directory containing TFRecord files of SequenceExample.
phase: str, dataset for evaluation {'train', 'valid', 'test'}.
bidir: bool, bidirectional LSTM.
pretrain: bool, whether to read pretraining data or classification data.
use_seq2seq: bool, whether to read seq2seq data or the language model data.
state_name: string, key for saved state of LSTM.
state_size: int, size of LSTM state.
num_layers: int, the number of LSTM layers.
batch_size: int, batch size.
    unroll_steps: int, number of timesteps to unroll for truncated BPTT.
Returns:
    A VatxtInput instance, or a (forward, reverse) pair of VatxtInput
    instances if bidir=True.
"""
with tf.name_scope('inputs'):
filenames = _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq)
if bidir and pretrain:
# Bidirectional pretraining
# Requires separate forward and reverse language model data.
forward_fname, reverse_fname = filenames
forward_batch = _read_and_batch(data_dir, forward_fname, state_name,
state_size, num_layers, unroll_steps,
batch_size)
state_name_rev = state_name + '_reverse'
reverse_batch = _read_and_batch(data_dir, reverse_fname, state_name_rev,
state_size, num_layers, unroll_steps,
batch_size)
forward_input = VatxtInput(
forward_batch, state_name=state_name, num_states=num_layers)
reverse_input = VatxtInput(
reverse_batch, state_name=state_name_rev, num_states=num_layers)
return forward_input, reverse_input
elif bidir:
# Classifier bidirectional LSTM
# Shared data source, but separate token/state streams
fname, = filenames
batch = _read_and_batch(
data_dir,
fname,
state_name,
state_size,
num_layers,
unroll_steps,
batch_size,
bidir_input=True)
forward_tokens, reverse_tokens = _split_bidir_tokens(batch)
forward_input = VatxtInput(
batch,
state_name=state_name,
tokens=forward_tokens,
num_states=num_layers)
reverse_input = VatxtInput(
batch,
state_name=state_name + '_reverse',
tokens=reverse_tokens,
num_states=num_layers)
return forward_input, reverse_input
else:
# Unidirectional LM or classifier
fname, = filenames
batch = _read_and_batch(
data_dir,
fname,
state_name,
state_size,
num_layers,
unroll_steps,
batch_size,
bidir_input=False)
return VatxtInput(batch, state_name=state_name, num_states=num_layers)
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for VatxtModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
K = tf.contrib.keras
def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):
"""Construct multiple ReLU layers with dropout and a linear layer."""
subgraph = K.models.Sequential(name='cl_logits')
for i, layer_size in enumerate(layer_sizes):
if i == 0:
subgraph.add(
K.layers.Dense(layer_size, activation='relu', input_dim=input_size))
else:
subgraph.add(K.layers.Dense(layer_size, activation='relu'))
    if keep_prob < 1.:
      # Keras Dropout takes the fraction of units to drop, not the keep prob.
      subgraph.add(K.layers.Dropout(1. - keep_prob))
subgraph.add(K.layers.Dense(1 if num_classes == 2 else num_classes))
return subgraph
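# For example (hypothetical settings): layer_sizes=[30], input_size=512,
# num_classes=2, keep_prob=0.8 yields Dense(30, relu) -> Dropout(0.2) ->
# Dense(1), i.e. a single logit for binary classification.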
class Embedding(K.layers.Layer):
"""Embedding layer with frequency-based normalization and dropout."""
def __init__(self,
vocab_size,
embedding_dim,
normalize=False,
vocab_freqs=None,
keep_prob=1.,
**kwargs):
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.normalized = normalize
self.keep_prob = keep_prob
if normalize:
assert vocab_freqs is not None
self.vocab_freqs = tf.constant(
vocab_freqs, dtype=tf.float32, shape=(vocab_size, 1))
super(Embedding, self).__init__(**kwargs)
def build(self, input_shape):
with tf.device('/cpu:0'):
self.var = self.add_weight(
shape=(self.vocab_size, self.embedding_dim),
initializer=tf.random_uniform_initializer(-1., 1.),
name='embedding')
if self.normalized:
self.var = self._normalize(self.var)
super(Embedding, self).build(input_shape)
def call(self, x):
embedded = tf.nn.embedding_lookup(self.var, x)
if self.keep_prob < 1.:
embedded = tf.nn.dropout(embedded, self.keep_prob)
return embedded
def _normalize(self, emb):
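    # Frequency-weighted normalization: center each embedding dimension by the
    # vocab-frequency-weighted mean and scale by the frequency-weighted
    # standard deviation across the vocabulary.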
weights = self.vocab_freqs / tf.reduce_sum(self.vocab_freqs)
emb -= tf.reduce_sum(weights * emb, 0, keep_dims=True)
emb /= tf.sqrt(1e-6 + tf.reduce_sum(
weights * tf.pow(emb, 2.), 0, keep_dims=True))
return emb
class LSTM(object):
"""LSTM layer using static_rnn.
Exposes variables in `trainable_weights` property.
"""
def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'):
self.cell_size = cell_size
self.num_layers = num_layers
self.keep_prob = keep_prob
self.reuse = None
self.trainable_weights = None
self.name = name
def __call__(self, x, initial_state, seq_length):
with tf.variable_scope(self.name, reuse=self.reuse) as vs:
cell = tf.contrib.rnn.MultiRNNCell([
tf.contrib.rnn.BasicLSTMCell(
self.cell_size,
forget_bias=0.0,
reuse=tf.get_variable_scope().reuse)
for _ in xrange(self.num_layers)
])
# shape(x) = (batch_size, num_timesteps, embedding_dim)
# Convert into a time-major list for static_rnn
x = tf.unstack(tf.transpose(x, perm=[1, 0, 2]))
lstm_out, next_state = tf.contrib.rnn.static_rnn(
cell, x, initial_state=initial_state, sequence_length=seq_length)
# Merge time and batch dimensions
# shape(lstm_out) = timesteps * (batch_size, cell_size)
lstm_out = tf.concat(lstm_out, 0)
# shape(lstm_out) = (timesteps*batch_size, cell_size)
if self.keep_prob < 1.:
lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)
if self.reuse is None:
self.trainable_weights = vs.global_variables()
self.reuse = True
return lstm_out, next_state
class SoftmaxLoss(K.layers.Layer):
"""Softmax xentropy loss with candidate sampling."""
def __init__(self,
vocab_size,
num_candidate_samples=-1,
vocab_freqs=None,
**kwargs):
self.vocab_size = vocab_size
self.num_candidate_samples = num_candidate_samples
self.vocab_freqs = vocab_freqs
super(SoftmaxLoss, self).__init__(**kwargs)
def build(self, input_shape):
input_shape = input_shape[0]
with tf.device('/cpu:0'):
self.lin_w = self.add_weight(
shape=(input_shape[-1], self.vocab_size),
name='lm_lin_w',
initializer='glorot_uniform')
self.lin_b = self.add_weight(
shape=(self.vocab_size,),
name='lm_lin_b',
initializer='glorot_uniform')
super(SoftmaxLoss, self).build(input_shape)
def call(self, inputs):
x, labels, weights = inputs
if self.num_candidate_samples > -1:
assert self.vocab_freqs is not None
labels = tf.expand_dims(labels, -1)
sampled = tf.nn.fixed_unigram_candidate_sampler(
true_classes=labels,
num_true=1,
num_sampled=self.num_candidate_samples,
unique=True,
range_max=self.vocab_size,
unigrams=self.vocab_freqs)
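      # tf.nn.sampled_softmax_loss expects class weights of shape
      # [num_classes, dim], hence the transpose of lin_w, which is stored as
      # [dim, vocab_size].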
lm_loss = tf.nn.sampled_softmax_loss(
weights=tf.transpose(self.lin_w),
biases=self.lin_b,
labels=labels,
inputs=x,
num_sampled=self.num_candidate_samples,
num_classes=self.vocab_size,
sampled_values=sampled)
else:
logits = tf.matmul(x, self.lin_w) + self.lin_b
lm_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels)
lm_loss = tf.identity(
tf.reduce_sum(lm_loss * weights) / _num_labels(weights),
name='lm_xentropy_loss')
return lm_loss
def classification_loss(logits, labels, weights):
"""Computes cross entropy loss between logits and labels.
Args:
logits: 2-D [timesteps*batch_size, m] float tensor, where m=1 if
num_classes=2, otherwise m=num_classes.
labels: 1-D [timesteps*batch_size] integer tensor.
    weights: 1-D [timesteps*batch_size] float tensor.
Returns:
Loss scalar of type float.
"""
inner_dim = logits.get_shape().as_list()[-1]
with tf.name_scope('classifier_loss'):
# Logistic loss
if inner_dim == 1:
loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=tf.squeeze(logits), labels=tf.cast(labels, tf.float32))
# Softmax loss
else:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels)
num_lab = _num_labels(weights)
tf.summary.scalar('num_labels', num_lab)
return tf.identity(
tf.reduce_sum(weights * loss) / num_lab, name='classification_xentropy')
def accuracy(logits, targets, weights):
"""Computes prediction accuracy.
Args:
logits: 2-D classifier logits [timesteps*batch_size, num_classes]
targets: 1-D [timesteps*batch_size] integer tensor.
weights: 1-D [timesteps*batch_size] float tensor.
Returns:
Accuracy: float scalar.
"""
with tf.name_scope('accuracy'):
eq = tf.cast(tf.equal(predictions(logits), targets), tf.float32)
return tf.identity(
tf.reduce_sum(weights * eq) / _num_labels(weights), name='accuracy')
def predictions(logits):
"""Class prediction from logits."""
inner_dim = logits.get_shape().as_list()[-1]
with tf.name_scope('predictions'):
    # For binary classification: the sigmoid probability crosses 0.5 exactly
    # when the logit crosses 0, so threshold the logit at 0.
    if inner_dim == 1:
      pred = tf.cast(tf.greater(tf.squeeze(logits), 0.), tf.int64)
# For multi-class classification
else:
pred = tf.argmax(logits, 1)
return pred
def _num_labels(weights):
"""Number of 1's in weights. Returns 1. if 0."""
num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
return num_labels
def optimize(loss,
global_step,
max_grad_norm,
lr,
lr_decay,
sync_replicas=False,
replicas_to_aggregate=1,
task_id=0):
"""Builds optimization graph.
* Creates an optimizer, and optionally wraps with SyncReplicasOptimizer
* Computes, clips, and applies gradients
* Maintains moving averages for all trainable variables
* Summarizes variables and gradients
Args:
loss: scalar loss to minimize.
global_step: integer scalar Variable.
max_grad_norm: float scalar. Grads will be clipped to this value.
lr: float scalar, learning rate.
lr_decay: float scalar, learning rate decay rate.
sync_replicas: bool, whether to use SyncReplicasOptimizer.
replicas_to_aggregate: int, number of replicas to aggregate when using
SyncReplicasOptimizer.
task_id: int, id of the current task; used to ensure proper initialization
of SyncReplicasOptimizer.
Returns:
train_op
"""
with tf.name_scope('optimization'):
# Compute gradients.
tvars = tf.trainable_variables()
grads = tf.gradients(
loss,
tvars,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
# Clip non-embedding grads
non_embedding_grads_and_vars = [(g, v) for (g, v) in zip(grads, tvars)
if 'embedding' not in v.op.name]
embedding_grads_and_vars = [(g, v) for (g, v) in zip(grads, tvars)
if 'embedding' in v.op.name]
ne_grads, ne_vars = zip(*non_embedding_grads_and_vars)
ne_grads, _ = tf.clip_by_global_norm(ne_grads, max_grad_norm)
non_embedding_grads_and_vars = zip(ne_grads, ne_vars)
grads_and_vars = embedding_grads_and_vars + non_embedding_grads_and_vars
# Summarize
_summarize_vars_and_grads(grads_and_vars)
# Decaying learning rate
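    # With decay_steps=1 the effective rate is lr * lr_decay**global_step;
    # e.g. (assumed) lr_decay=0.9999 roughly halves the rate every ~6,900
    # steps.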
lr = tf.train.exponential_decay(
lr, global_step, 1, lr_decay, staircase=True)
tf.summary.scalar('learning_rate', lr)
opt = tf.train.AdamOptimizer(lr)
# Track the moving averages of all trainable variables.
variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
# Apply gradients
if sync_replicas:
opt = tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate,
variable_averages=variable_averages,
variables_to_average=tvars,
total_num_replicas=replicas_to_aggregate)
apply_gradient_op = opt.apply_gradients(
grads_and_vars, global_step=global_step)
with tf.control_dependencies([apply_gradient_op]):
train_op = tf.no_op(name='train_op')
# Initialization ops
tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS,
opt.get_chief_queue_runner())
if task_id == 0: # Chief task
local_init_op = opt.chief_init_op
tf.add_to_collection('chief_init_op', opt.get_init_tokens_op())
else:
local_init_op = opt.local_step_init_op
tf.add_to_collection('local_init_op', local_init_op)
tf.add_to_collection('ready_for_local_init_op',
opt.ready_for_local_init_op)
else:
# Non-sync optimizer
variables_averages_op = variable_averages.apply(tvars)
apply_gradient_op = opt.apply_gradients(grads_and_vars, global_step)
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
train_op = tf.no_op(name='train_op')
return train_op
def _summarize_vars_and_grads(grads_and_vars):
tf.logging.info('Trainable variables:')
tf.logging.info('-' * 60)
for grad, var in grads_and_vars:
tf.logging.info(var)
def tag(name, v=var):
return v.op.name + '_' + name
# Variable summary
mean = tf.reduce_mean(var)
tf.summary.scalar(tag('mean'), mean)
with tf.name_scope(tag('stddev')):
stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
tf.summary.scalar(tag('stddev'), stddev)
tf.summary.scalar(tag('max'), tf.reduce_max(var))
tf.summary.scalar(tag('min'), tf.reduce_min(var))
tf.summary.histogram(tag('histogram'), var)
# Gradient summary
if grad is not None:
if isinstance(grad, tf.IndexedSlices):
grad_values = grad.values
else:
grad_values = grad
tf.summary.histogram(tag('gradient'), grad_values)
tf.summary.scalar(tag('gradient_norm'), tf.global_norm([grad_values]))
else:
tf.logging.info('Var %s has no gradient', var.op.name)
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pretrains a recurrent language model.
Approximate computational cost:
  about 5 days to train 100,000 steps with a 1-layer, 1024-hidden-unit LSTM,
  256-dim embeddings, 400-step truncated BPTT, and a per-replica batch size
  of 64 on 4 GPUs with SyncReplicasOptimizer (effective batch size 256).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import graphs
import train_utils
FLAGS = tf.app.flags.FLAGS
def main(_):
"""Trains Language Model."""
tf.logging.set_verbosity(tf.logging.INFO)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
model = graphs.get_model()
train_op, loss, global_step = model.language_model_training()
train_utils.run_training(train_op, loss, global_step)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains LSTM text classification model.
The model can be trained with adversarial or virtual adversarial training.
Approximate computational cost:
  about 6 hours to train 10,000 steps without adversarial or virtual
  adversarial training, with a 1-layer, 1024-hidden-unit LSTM, 256-dim
  embeddings, 400-step truncated BPTT, and batch size 64 on a single GPU.
  about 12 hours to train 10,000 steps with adversarial or virtual
  adversarial training under the same configuration.
To initialize embedding and LSTM cell weights from a pretrained model, set
FLAGS.pretrained_model_dir to the pretrained model's checkpoint directory.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import graphs
import train_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('pretrained_model_dir', None,
'Directory path to pretrained model to restore from')
def main(_):
"""Trains LSTM classification model."""
tf.logging.set_verbosity(tf.logging.INFO)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
model = graphs.get_model()
train_op, loss, global_step = model.classifier_training()
train_utils.run_training(
train_op,
loss,
global_step,
variables_to_restore=model.pretrained_variables,
pretrained_model_dir=FLAGS.pretrained_model_dir)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for training adversarial text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'Master address.')
flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of parameter servers.')
flags.DEFINE_string('train_dir', '/tmp/text_train',
'Directory for logs and checkpoints.')
flags.DEFINE_integer('max_steps', 1000000, 'Number of batches to run.')
flags.DEFINE_boolean('log_device_placement', False,
'Whether to log device placement.')
def run_training(train_op,
loss,
global_step,
variables_to_restore=None,
pretrained_model_dir=None):
"""Sets up and runs training loop."""
tf.gfile.MakeDirs(FLAGS.train_dir)
# Create pretrain Saver
if pretrained_model_dir:
assert variables_to_restore
tf.logging.info('Will attempt restore from %s: %s', pretrained_model_dir,
variables_to_restore)
saver_for_restore = tf.train.Saver(variables_to_restore)
# Init ops
if FLAGS.sync_replicas:
local_init_op = tf.get_collection('local_init_op')[0]
ready_for_local_init_op = tf.get_collection('ready_for_local_init_op')[0]
else:
local_init_op = tf.train.Supervisor.USE_DEFAULT
ready_for_local_init_op = tf.train.Supervisor.USE_DEFAULT
is_chief = FLAGS.task == 0
sv = tf.train.Supervisor(
logdir=FLAGS.train_dir,
is_chief=is_chief,
save_summaries_secs=5 * 60,
save_model_secs=5 * 60,
local_init_op=local_init_op,
ready_for_local_init_op=ready_for_local_init_op,
global_step=global_step)
# Delay starting standard services to allow possible pretrained model restore.
with sv.managed_session(
master=FLAGS.master,
config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement),
start_standard_services=False) as sess:
# Initialization
if is_chief:
if pretrained_model_dir:
maybe_restore_pretrained_model(sess, saver_for_restore,
pretrained_model_dir)
if FLAGS.sync_replicas:
sess.run(tf.get_collection('chief_init_op')[0])
sv.start_standard_services(sess)
sv.start_queue_runners(sess)
# Training loop
global_step_val = 0
while not sv.should_stop() and global_step_val < FLAGS.max_steps:
global_step_val = train_step(sess, train_op, loss, global_step)
sv.stop()
# Final checkpoint
if is_chief:
sv.saver.save(sess, sv.save_path, global_step=global_step)
def maybe_restore_pretrained_model(sess, saver_for_restore, model_dir):
"""Restores pretrained model if there is no ckpt model."""
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
checkpoint_exists = ckpt and ckpt.model_checkpoint_path
if checkpoint_exists:
tf.logging.info('Checkpoint exists in FLAGS.train_dir; skipping '
'pretraining restore')
return
pretrain_ckpt = tf.train.get_checkpoint_state(model_dir)
if not (pretrain_ckpt and pretrain_ckpt.model_checkpoint_path):
raise ValueError(
'Asked to restore model from %s but no checkpoint found.' % model_dir)
saver_for_restore.restore(sess, pretrain_ckpt.model_checkpoint_path)
def train_step(sess, train_op, loss, global_step):
"""Runs a single training step."""
start_time = time.time()
_, loss_val, global_step_val = sess.run([train_op, loss, global_step])
duration = time.time() - start_time
# Logging
if global_step_val % 10 == 0:
examples_per_sec = FLAGS.batch_size / duration
sec_per_batch = float(duration)
    format_str = 'step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
tf.logging.info(format_str % (global_step_val, loss_val, examples_per_sec,
sec_per_batch))
if np.isnan(loss_val):
raise OverflowError('Loss is nan')
return global_step_val