diff --git a/adversarial_text/BUILD b/adversarial_text/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b0fdc6332f96e6c55101dc4c707c5c9d8609da02 --- /dev/null +++ b/adversarial_text/BUILD @@ -0,0 +1,76 @@ +# Binaries +# ============================================================================== +py_binary( + name = "evaluate", + srcs = ["evaluate.py"], + deps = [ + ":graphs", + ], +) + +py_binary( + name = "train_classifier", + srcs = ["train_classifier.py"], + deps = [ + ":graphs", + ":train_utils", + ], +) + +py_binary( + name = "pretrain", + srcs = [ + "pretrain.py", + ], + deps = [ + ":graphs", + ":train_utils", + ], +) + +# Libraries +# ============================================================================== +py_library( + name = "graphs", + srcs = ["graphs.py"], + deps = [ + ":adversarial_losses", + ":inputs", + ":layers", + ], +) + +py_library( + name = "adversarial_losses", + srcs = ["adversarial_losses.py"], +) + +py_library( + name = "inputs", + srcs = ["inputs.py"], + deps = [ + "//adversarial_text/data:data_utils", + ], +) + +py_library( + name = "layers", + srcs = ["layers.py"], +) + +py_library( + name = "train_utils", + srcs = ["train_utils.py"], +) + +# Tests +# ============================================================================== +py_test( + name = "graphs_test", + size = "large", + srcs = ["graphs_test.py"], + deps = [ + ":graphs", + "//adversarial_text/data:data_utils", + ], +) diff --git a/adversarial_text/README.md b/adversarial_text/README.md new file mode 100644 index 0000000000000000000000000000000000000000..05952c00f5bee83a1119f943bffd01ad7a5fe4b3 --- /dev/null +++ b/adversarial_text/README.md @@ -0,0 +1,159 @@ +# Adversarial Text Classification + +Code for *Adversarial Training Methods for Semi-Supervised Text Classification* +(https://arxiv.org/abs/1605.07725) and *Semi-Supervised Sequence Learning* +(https://arxiv.org/abs/1511.01432). + +## Requirements + +* Bazel ([install](https://bazel.build/versions/master/docs/install.html)) +* TensorFlow >= v1.1 + +## End-to-end IMDB Sentiment Classification + +### Fetch data + +``` +$ wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz \ + -O /tmp/imdb.tar.gz +$ tar -xf /tmp/imdb.tar.gz -C /tmp +``` + +The directory `/tmp/aclImdb` contains the raw IMDB data. + +### Generate vocabulary + +``` +$ IMDB_DATA_DIR=/tmp/imdb +$ bazel run data:gen_vocab -- \ + --output_dir=$IMDB_DATA_DIR \ + --dataset=imdb \ + --imdb_input_dir=/tmp/aclImdb \ + --lowercase=False +``` + +Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`. + +###  Generate training, validation, and test data + +``` +$ bazel run data:gen_data -- \ + --output_dir=$IMDB_DATA_DIR \ + --dataset=imdb \ + --imdb_input_dir=/tmp/aclImdb \ + --lowercase=False \ + --label_gain=False +``` + +`$IMDB_DATA_DIR` contains TFRecords files. + +### Pretrain IMDB Language Model + +``` +$ PRETRAIN_DIR=/tmp/models/imdb_pretrain +$ bazel run :pretrain -- \ + --train_dir=$PRETRAIN_DIR \ + --data_dir=$IMDB_DATA_DIR \ + --vocab_size=86934 \ + --embedding_dims=256 \ + --rnn_cell_size=1024 \ + --num_candidate_samples=1024 \ + --optimizer=adam \ + --batch_size=256 \ + --learning_rate=0.001 \ + --learning_rate_decay_factor=0.9999 \ + --max_steps=100000 \ + --max_grad_norm=1.0 \ + --num_timesteps=400 \ + --keep_prob_emb=0.5 \ + --normalize_embeddings +``` + +`$PRETRAIN_DIR` contains checkpoints of the pretrained language model. 
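+Before training the classifier, it can be worth confirming that the checkpoint
+actually contains the embedding and LSTM variables the classifier will restore
+from `pretrained_model_dir`. A minimal sketch, assuming TensorFlow >= 1.1 and
+the `$PRETRAIN_DIR` path above (the exact variable names depend on how
+`graphs.py` and `layers.py` name them):
+
+```
+import tensorflow as tf
+
+# List the variables stored in the most recent pretraining checkpoint.
+ckpt = tf.train.latest_checkpoint('/tmp/models/imdb_pretrain')
+reader = tf.train.NewCheckpointReader(ckpt)
+for name, shape in sorted(reader.get_variable_to_shape_map().items()):
+    print(name, shape)
+```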
+ +### Train classifier + +Most flags stay the same, save for the removal of candidate sampling and the +addition of `pretrained_model_dir`, from which the classifier will load the +pretrained embedding and LSTM variables, and flags related to adversarial +training and classification. + +``` +$ TRAIN_DIR=/tmp/models/imdb_classify +$ bazel run :train_classifier -- \ + --train_dir=$TRAIN_DIR \ + --pretrained_model_dir=$PRETRAIN_DIR \ + --data_dir=$IMDB_DATA_DIR \ + --vocab_size=86934 \ + --embedding_dims=256 \ + --rnn_cell_size=1024 \ + --cl_num_layers=1 \ + --cl_hidden_size=30 \ + --optimizer=adam \ + --batch_size=64 \ + --learning_rate=0.0005 \ + --learning_rate_decay_factor=0.9998 \ + --max_steps=15000 \ + --max_grad_norm=1.0 \ + --num_timesteps=400 \ + --keep_prob_emb=0.5 \ + --normalize_embeddings \ + --adv_training_method=vat +``` + +### Evaluate on test data + +``` +$ EVAL_DIR=/tmp/models/imdb_eval +$ bazel run :evaluate -- \ + --eval_dir=$EVAL_DIR \ + --checkpoint_dir=$TRAIN_DIR \ + --eval_data=test \ + --run_once \ + --num_examples=25000 \ + --data_dir=$IMDB_DATA_DIR \ + --vocab_size=86934 \ + --embedding_dims=256 \ + --rnn_cell_size=1024 \ + --batch_size=256 \ + --num_timesteps=400 \ + --normalize_embeddings +``` + +## Code Overview + +The main entry points are the binaries listed below. Each training binary builds +a `VatxtModel`, defined in `graphs.py`, which in turn uses graph building blocks +defined in `inputs.py` (defines input data reading and parsing), `layers.py` +(defines core model components), and `adversarial_losses.py` (defines +adversarial training losses). The training loop itself is defined in +`train_utils.py`. + +### Binaries + +* Pretraining: `pretrain.py` +* Classifier Training: `train_classifier.py` +* Evaluation: `evaluate.py` + +### Command-Line Flags + +Flags related to distributed training and the training loop itself are defined +in `train_utils.py`. + +Flags related to model hyperparameters are defined in `graphs.py`. + +Flags related to adversarial training are defined in `adversarial_losses.py`. + +Flags particular to each job are defined in the main binary files. + +### Data Generation + +* Vocabulary generation: `gen_vocab.py` +* Data generation: `gen_data.py` + +Command-line flags defined in `document_generators.py` control which dataset is +processed and how. + +## Contact for Issues + +* Ryan Sepassi, @rsepassi diff --git a/adversarial_text/adversarial_losses.py b/adversarial_text/adversarial_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd4656239c1711a25779a116fc26437fcc54dca --- /dev/null +++ b/adversarial_text/adversarial_losses.py @@ -0,0 +1,229 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Adversarial losses for text models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +flags = tf.app.flags +FLAGS = flags.FLAGS + +# Adversarial and virtual adversarial training parameters. +flags.DEFINE_float('perturb_norm_length', 0.1, + 'Norm length of adversarial perturbation to be ' + 'optimized with validation') + +# Virtual adversarial training parameters +flags.DEFINE_integer('num_power_iteration', 1, 'The number of power iteration') +flags.DEFINE_float('small_constant_for_finite_diff', 1e-3, + 'Small constant for finite difference method') + +# Parameters for building the graph +flags.DEFINE_string('adv_training_method', None, + 'The flag which specifies training method. ' + '"rp" : random perturbation training ' + '"at" : adversarial training ' + '"vat" : virtual adversarial training ' + '"atvat" : at + vat ') +flags.DEFINE_float('adv_reg_coeff', 1.0, + 'Regularization coefficient of adversarial loss.') + + +def random_perturbation_loss(embedded, length, loss_fn): + """Adds noise to embeddings and recomputes classification loss.""" + noise = tf.random_normal(shape=tf.shape(embedded)) + perturb = _scale_l2(_mask_by_length(noise, length), FLAGS.perturb_norm_length) + return loss_fn(embedded + perturb) + + +def adversarial_loss(embedded, loss, loss_fn): + """Adds gradient to embedding and recomputes classification loss.""" + grad, = tf.gradients( + loss, + embedded, + aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) + grad = tf.stop_gradient(grad) + perturb = _scale_l2(grad, FLAGS.perturb_norm_length) + return loss_fn(embedded + perturb) + + +def virtual_adversarial_loss(logits, embedded, inputs, + logits_from_embedding_fn): + """Virtual adversarial loss. + + Computes virtual adversarial perturbation by finite difference method and + power iteration, adds it to the embedding, and computes the KL divergence + between the new logits and the original logits. + + Args: + logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if + num_classes=2, otherwise m=num_classes. + embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim]. + inputs: VatxtInput. + logits_from_embedding_fn: callable that takes embeddings and returns + classifier logits. + + Returns: + kl: float scalar. + """ + # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details. + logits = tf.stop_gradient(logits) + weights = _end_of_seq_mask(inputs.labels) + + # shape(embedded) = (batch_size, num_timesteps, embedding_dim) + d = _mask_by_length(tf.random_normal(shape=tf.shape(embedded)), inputs.length) + + # Perform finite difference method and power iteration. + # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf, + # Adding small noise to input and taking gradient with respect to the noise + # corresponds to 1 power iteration. 
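+  # Summary of the loop below: each iteration rescales d to a small
+  # finite-difference step, measures KL(p(y|x) || p(y|x+d)) between the
+  # original and perturbed predictions, and replaces d with the gradient of
+  # that KL with respect to d. After FLAGS.num_power_iteration iterations, d
+  # approximates the direction in embedding space to which the model's
+  # predictions are most sensitive.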
+ for _ in xrange(FLAGS.num_power_iteration): + d = _scale_l2(d, FLAGS.small_constant_for_finite_diff) + d_logits = logits_from_embedding_fn(embedded + d) + kl = _kl_divergence_with_logits(logits, d_logits, weights) + d, = tf.gradients( + kl, + d, + aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) + d = tf.stop_gradient(d) + + perturb = _scale_l2( + _mask_by_length(d, inputs.length), FLAGS.perturb_norm_length) + vadv_logits = logits_from_embedding_fn(embedded + perturb) + return _kl_divergence_with_logits(logits, vadv_logits, weights) + + +def random_perturbation_loss_bidir(embedded, length, loss_fn): + """Adds noise to embeddings and recomputes classification loss.""" + noise = [tf.random_normal(shape=tf.shape(emb)) for emb in embedded] + masked = [_mask_by_length(n, length) for n in noise] + scaled = [_scale_l2(m, FLAGS.perturb_norm_length) for m in masked] + return loss_fn([e + s for (e, s) in zip(embedded, scaled)]) + + +def adversarial_loss_bidir(embedded, loss, loss_fn): + """Adds gradient to embeddings and recomputes classification loss.""" + grads = tf.gradients( + loss, + embedded, + aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) + adv_exs = [ + emb + _scale_l2(tf.stop_gradient(g), FLAGS.perturb_norm_length) + for emb, g in zip(embedded, grads) + ] + return loss_fn(adv_exs) + + +def virtual_adversarial_loss_bidir(logits, embedded, inputs, + logits_from_embedding_fn): + """Virtual adversarial loss for bidirectional models.""" + logits = tf.stop_gradient(logits) + f_inputs, _ = inputs + weights = _end_of_seq_mask(f_inputs.labels) + + perturbs = [ + _mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length) + for emb in embedded + ] + for _ in xrange(FLAGS.num_power_iteration): + perturbs = [ + _scale_l2(d, FLAGS.small_constant_for_finite_diff) for d in perturbs + ] + d_logits = logits_from_embedding_fn( + [emb + d for (emb, d) in zip(embedded, perturbs)]) + kl = _kl_divergence_with_logits(logits, d_logits, weights) + perturbs = tf.gradients( + kl, + perturbs, + aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) + perturbs = [tf.stop_gradient(d) for d in perturbs] + + perturbs = [ + _scale_l2(_mask_by_length(d, f_inputs.length), FLAGS.perturb_norm_length) + for d in perturbs + ] + vadv_logits = logits_from_embedding_fn( + [emb + d for (emb, d) in zip(embedded, perturbs)]) + return _kl_divergence_with_logits(logits, vadv_logits, weights) + + +def _mask_by_length(t, length): + """Mask t, 3-D [batch, time, dim], by length, 1-D [batch,].""" + maxlen = t.get_shape().as_list()[1] + mask = tf.sequence_mask(length, maxlen=maxlen) + mask = tf.expand_dims(tf.cast(mask, tf.float32), -1) + # shape(mask) = (batch, num_timesteps, 1) + return t * mask + + +def _scale_l2(x, norm_length): + # shape(x) = (batch, num_timesteps, d) + x /= (1e-12 + tf.reduce_max(tf.abs(x), 2, keep_dims=True)) + x_2 = tf.reduce_sum(tf.pow(x, 2), 2, keep_dims=True) + x /= tf.sqrt(1e-6 + x_2) + + return norm_length * x + + +def _end_of_seq_mask(tokens): + """Generate a mask for the EOS token (1.0 on EOS, 0.0 otherwise). + + Args: + tokens: 1-D integer tensor [num_timesteps*batch_size]. Each element is an + id from the vocab. + + Returns: + Float tensor same shape as tokens, whose values are 1.0 on the end of + sequence and 0.0 on the others. 
+ """ + eos_id = FLAGS.vocab_size - 1 + return tf.cast(tf.equal(tokens, eos_id), tf.float32) + + +def _kl_divergence_with_logits(q_logits, p_logits, weights): + """Returns weighted KL divergence between distributions q and p. + + Args: + q_logits: logits for 1st argument of KL divergence shape + [num_timesteps * batch_size, num_classes] if num_classes > 2, and + [num_timesteps * batch_size] if num_classes == 2. + p_logits: logits for 2nd argument of KL divergence with same shape q_logits. + weights: 1-D float tensor with shape [num_timesteps * batch_size]. + Elements should be 1.0 only on end of sequences + + Returns: + KL: float scalar. + """ + # For logistic regression + if FLAGS.num_classes == 2: + q = tf.nn.sigmoid(q_logits) + p = tf.nn.sigmoid(p_logits) + kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) + + tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q)) + + # For softmax regression + else: + q = tf.nn.softmax(q_logits) + p = tf.nn.softmax(p_logits) + kl = tf.reduce_sum(q * (tf.log(q) - tf.log(p)), 1) + + num_labels = tf.reduce_sum(weights) + num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels) + + loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl') + return loss diff --git a/adversarial_text/data/BUILD b/adversarial_text/data/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..33d46bcc1643be6964da810dc83e0ce11cdfcc30 --- /dev/null +++ b/adversarial_text/data/BUILD @@ -0,0 +1,41 @@ +package( + default_visibility = [ + "//adversarial_text:__subpackages__", + ], +) + +py_binary( + name = "gen_vocab", + srcs = ["gen_vocab.py"], + deps = [ + ":data_utils", + ":document_generators", + ], +) + +py_binary( + name = "gen_data", + srcs = ["gen_data.py"], + deps = [ + ":data_utils", + ":document_generators", + ], +) + +py_library( + name = "document_generators", + srcs = ["document_generators.py"], +) + +py_library( + name = "data_utils", + srcs = ["data_utils.py"], +) + +py_test( + name = "data_utils_test", + srcs = ["data_utils_test.py"], + deps = [ + ":data_utils", + ], +) diff --git a/adversarial_text/data/data_utils.py b/adversarial_text/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c31ab96d1d8e4573b9657d3792bd880da3d50ec --- /dev/null +++ b/adversarial_text/data/data_utils.py @@ -0,0 +1,326 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Utilities for generating/preprocessing data for adversarial text models.""" + +import operator +import os +import random +import re +import tensorflow as tf + +EOS_TOKEN = '' + +# Data filenames +# Sequence Autoencoder +ALL_SA = 'all_sa.tfrecords' +TRAIN_SA = 'train_sa.tfrecords' +TEST_SA = 'test_sa.tfrecords' +# Language Model +ALL_LM = 'all_lm.tfrecords' +TRAIN_LM = 'train_lm.tfrecords' +TEST_LM = 'test_lm.tfrecords' +# Classification +TRAIN_CLASS = 'train_classification.tfrecords' +TEST_CLASS = 'test_classification.tfrecords' +VALID_CLASS = 'validate_classification.tfrecords' +# LM with bidirectional LSTM +TRAIN_REV_LM = 'train_reverse_lm.tfrecords' +TEST_REV_LM = 'test_reverse_lm.tfrecords' +# Classification with bidirectional LSTM +TRAIN_BD_CLASS = 'train_bidir_classification.tfrecords' +TEST_BD_CLASS = 'test_bidir_classification.tfrecords' +VALID_BD_CLASS = 'validate_bidir_classification.tfrecords' + + +class ShufflingTFRecordWriter(object): + """Thin wrapper around TFRecordWriter that shuffles records.""" + + def __init__(self, path): + self._path = path + self._records = [] + self._closed = False + + def write(self, record): + assert not self._closed + self._records.append(record) + + def close(self): + assert not self._closed + random.shuffle(self._records) + with tf.python_io.TFRecordWriter(self._path) as f: + for record in self._records: + f.write(record) + self._closed = True + + def __enter__(self): + return self + + def __exit__(self, unused_type, unused_value, unused_traceback): + self.close() + + +class Timestep(object): + """Represents a single timestep in a SequenceWrapper.""" + + def __init__(self, token, label, weight, multivalent_tokens=False): + """Constructs Timestep from empty Features.""" + self._token = token + self._label = label + self._weight = weight + self._multivalent_tokens = multivalent_tokens + self._fill_with_defaults() + + @property + def token(self): + if self._multivalent_tokens: + raise TypeError('Timestep may contain multiple values; use `tokens`') + return self._token.int64_list.value[0] + + @property + def tokens(self): + return self._token.int64_list.value + + @property + def label(self): + return self._label.int64_list.value[0] + + @property + def weight(self): + return self._weight.float_list.value[0] + + def set_token(self, token): + if self._multivalent_tokens: + raise TypeError('Timestep may contain multiple values; use `add_token`') + self._token.int64_list.value[0] = token + return self + + def add_token(self, token): + self._token.int64_list.value.append(token) + return self + + def set_label(self, label): + self._label.int64_list.value[0] = label + return self + + def set_weight(self, weight): + self._weight.float_list.value[0] = weight + return self + + def copy_from(self, timestep): + self.set_token(timestep.token).set_label(timestep.label).set_weight( + timestep.weight) + return self + + def _fill_with_defaults(self): + if not self._multivalent_tokens: + self._token.int64_list.value.append(0) + self._label.int64_list.value.append(0) + self._weight.float_list.value.append(0.0) + + +class SequenceWrapper(object): + """Wrapper around tf.SequenceExample.""" + + F_TOKEN_ID = 'token_id' + F_LABEL = 'label' + F_WEIGHT = 'weight' + + def __init__(self, multivalent_tokens=False): + self._seq = tf.train.SequenceExample() + self._flist = self._seq.feature_lists.feature_list + self._timesteps = [] + self._multivalent_tokens = multivalent_tokens + + @property + 
def seq(self): + return self._seq + + @property + def multivalent_tokens(self): + return self._multivalent_tokens + + @property + def _tokens(self): + return self._flist[SequenceWrapper.F_TOKEN_ID].feature + + @property + def _labels(self): + return self._flist[SequenceWrapper.F_LABEL].feature + + @property + def _weights(self): + return self._flist[SequenceWrapper.F_WEIGHT].feature + + def add_timestep(self): + timestep = Timestep( + self._tokens.add(), + self._labels.add(), + self._weights.add(), + multivalent_tokens=self._multivalent_tokens) + self._timesteps.append(timestep) + return timestep + + def __iter__(self): + for timestep in self._timesteps: + yield timestep + + def __len__(self): + return len(self._timesteps) + + def __getitem__(self, idx): + return self._timesteps[idx] + + +def build_reverse_sequence(seq): + """Builds a sequence that is the reverse of the input sequence.""" + reverse_seq = SequenceWrapper() + + # Copy all but last timestep + for timestep in reversed(seq[:-1]): + reverse_seq.add_timestep().copy_from(timestep) + + # Copy final timestep + reverse_seq.add_timestep().copy_from(seq[-1]) + + return reverse_seq + + +def build_bidirectional_seq(seq, rev_seq): + bidir_seq = SequenceWrapper(multivalent_tokens=True) + for forward_ts, reverse_ts in zip(seq, rev_seq): + bidir_seq.add_timestep().add_token(forward_ts.token).add_token( + reverse_ts.token) + + return bidir_seq + + +def build_lm_sequence(seq): + """Builds language model sequence from input sequence. + + Args: + seq: SequenceWrapper. + + Returns: + SequenceWrapper with `seq` tokens copied over to output sequence tokens and + labels (offset by 1, i.e. predict next token) with weights set to 1.0. + """ + lm_seq = SequenceWrapper() + for i, timestep in enumerate(seq[:-1]): + lm_seq.add_timestep().set_token(timestep.token).set_label( + seq[i + 1].token).set_weight(1.0) + + return lm_seq + + +def build_seq_ae_sequence(seq): + """Builds seq_ae sequence from input sequence. + + Args: + seq: SequenceWrapper. + + Returns: + SequenceWrapper with `seq` inputs copied and concatenated, and with labels + copied in on the right-hand (i.e. decoder) side with weights set to 1.0. + The new sequence will have length `len(seq) * 2 - 1`, as the last timestep + of the encoder section and the first step of the decoder section will + overlap. + """ + seq_ae_seq = SequenceWrapper() + + for i in range(len(seq) * 2 - 1): + ts = seq_ae_seq.add_timestep() + + if i < len(seq) - 1: + # Encoder + ts.set_token(seq[i].token) + elif i == len(seq) - 1: + # Transition step + ts.set_token(seq[i].token) + ts.set_label(seq[0].token) + ts.set_weight(1.0) + else: + # Decoder + ts.set_token(seq[i % len(seq)].token) + ts.set_label(seq[(i + 1) % len(seq)].token) + ts.set_weight(1.0) + + return seq_ae_seq + + +def build_labeled_sequence(seq, class_label, label_gain=False): + """Builds labeled sequence from input sequence. + + Args: + seq: SequenceWrapper. + class_label: bool. + label_gain: bool. If True, class_label will be put on every timestep and + weight will increase linearly from 0 to 1. + + Returns: + SequenceWrapper with `seq` copied in and `class_label` added as label to + final timestep. 
+ """ + label_seq = SequenceWrapper(multivalent_tokens=seq.multivalent_tokens) + + # Copy sequence without labels + seq_len = len(seq) + final_timestep = None + for i, timestep in enumerate(seq): + label_timestep = label_seq.add_timestep() + if seq.multivalent_tokens: + for token in timestep.tokens: + label_timestep.add_token(token) + else: + label_timestep.set_token(timestep.token) + if label_gain: + label_timestep.set_label(int(class_label)) + weight = 1.0 if seq_len < 2 else float(i) / (seq_len - 1) + label_timestep.set_weight(weight) + if i == (seq_len - 1): + final_timestep = label_timestep + + # Edit final timestep to have class label and weight = 1. + final_timestep.set_label(int(class_label)).set_weight(1.0) + + return label_seq + + +def split_by_punct(segment): + """Splits str segment by punctuation, filters our empties and spaces.""" + return [s for s in re.split(r'\W+', segment) if s and not s.isspace()] + + +def sort_vocab_by_frequency(vocab_freq_map): + """Sorts vocab_freq_map by count. + + Args: + vocab_freq_map: dict, vocabulary terms with counts. + + Returns: + list> sorted by count, descending. + """ + return sorted( + vocab_freq_map.items(), key=operator.itemgetter(1), reverse=True) + + +def write_vocab_and_frequency(ordered_vocab_freqs, output_dir): + """Writes ordered_vocab_freqs into vocab.txt and vocab_freq.txt.""" + tf.gfile.MakeDirs(output_dir) + with open(os.path.join(output_dir, 'vocab.txt'), 'w') as vocab_f: + with open(os.path.join(output_dir, 'vocab_freq.txt'), 'w') as freq_f: + for word, freq in ordered_vocab_freqs: + vocab_f.write('{}\n'.format(word)) + freq_f.write('{}\n'.format(freq)) diff --git a/adversarial_text/data/data_utils_test.py b/adversarial_text/data/data_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..614b12953e77f9d66503f1d1a1e0d81b98e84a14 --- /dev/null +++ b/adversarial_text/data/data_utils_test.py @@ -0,0 +1,192 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for data_utils.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from adversarial_text.data import data_utils + +data = data_utils + + +class SequenceWrapperTest(tf.test.TestCase): + + def testDefaultTimesteps(self): + seq = data.SequenceWrapper() + t1 = seq.add_timestep() + _ = seq.add_timestep() + self.assertEqual(len(seq), 2) + + self.assertEqual(t1.weight, 0.0) + self.assertEqual(t1.label, 0) + self.assertEqual(t1.token, 0) + + def testSettersAndGetters(self): + ts = data.SequenceWrapper().add_timestep() + ts.set_token(3) + ts.set_label(4) + ts.set_weight(2.0) + self.assertEqual(ts.token, 3) + self.assertEqual(ts.label, 4) + self.assertEqual(ts.weight, 2.0) + + def testTimestepIteration(self): + seq = data.SequenceWrapper() + seq.add_timestep().set_token(0) + seq.add_timestep().set_token(1) + seq.add_timestep().set_token(2) + for i, ts in enumerate(seq): + self.assertEqual(ts.token, i) + + def testFillsSequenceExampleCorrectly(self): + seq = data.SequenceWrapper() + seq.add_timestep().set_token(1).set_label(2).set_weight(3.0) + seq.add_timestep().set_token(10).set_label(20).set_weight(30.0) + + seq_ex = seq.seq + fl = seq_ex.feature_lists.feature_list + fl_token = fl[data.SequenceWrapper.F_TOKEN_ID].feature + fl_label = fl[data.SequenceWrapper.F_LABEL].feature + fl_weight = fl[data.SequenceWrapper.F_WEIGHT].feature + _ = [self.assertEqual(len(f), 2) for f in [fl_token, fl_label, fl_weight]] + self.assertAllEqual([f.int64_list.value[0] for f in fl_token], [1, 10]) + self.assertAllEqual([f.int64_list.value[0] for f in fl_label], [2, 20]) + self.assertAllEqual([f.float_list.value[0] for f in fl_weight], [3.0, 30.0]) + + +class DataUtilsTest(tf.test.TestCase): + + def testSplitByPunct(self): + output = data.split_by_punct( + 'hello! world, i\'ve been\nwaiting\tfor\ryou for.a long time') + expected = [ + 'hello', 'world', 'i', 've', 'been', 'waiting', 'for', 'you', 'for', + 'a', 'long', 'time' + ] + self.assertListEqual(output, expected) + + def _buildDummySequence(self): + seq = data.SequenceWrapper() + for i in range(10): + seq.add_timestep().set_token(i) + return seq + + def testBuildLMSeq(self): + seq = self._buildDummySequence() + lm_seq = data.build_lm_sequence(seq) + for i, ts in enumerate(lm_seq): + self.assertEqual(ts.token, i) + self.assertEqual(ts.label, i + 1) + self.assertEqual(ts.weight, 1.0) + + def testBuildSAESeq(self): + seq = self._buildDummySequence() + sa_seq = data.build_seq_ae_sequence(seq) + + self.assertEqual(len(sa_seq), len(seq) * 2 - 1) + + # Tokens should be sequence twice, minus the EOS token at the end + for i, ts in enumerate(sa_seq): + self.assertEqual(ts.token, seq[i % 10].token) + + # Weights should be len-1 0.0's and len 1.0's. 
+ for i in range(len(seq) - 1): + self.assertEqual(sa_seq[i].weight, 0.0) + for i in range(len(seq) - 1, len(sa_seq)): + self.assertEqual(sa_seq[i].weight, 1.0) + + # Labels should be len-1 0's, and then the sequence + for i in range(len(seq) - 1): + self.assertEqual(sa_seq[i].label, 0) + for i in range(len(seq) - 1, len(sa_seq)): + self.assertEqual(sa_seq[i].label, seq[i - (len(seq) - 1)].token) + + def testBuildLabelSeq(self): + seq = self._buildDummySequence() + eos_id = len(seq) - 1 + label_seq = data.build_labeled_sequence(seq, True) + for i, ts in enumerate(label_seq[:-1]): + self.assertEqual(ts.token, i) + self.assertEqual(ts.label, 0) + self.assertEqual(ts.weight, 0.0) + + final_timestep = label_seq[-1] + self.assertEqual(final_timestep.token, eos_id) + self.assertEqual(final_timestep.label, 1) + self.assertEqual(final_timestep.weight, 1.0) + + def testBuildBidirLabelSeq(self): + seq = self._buildDummySequence() + reverse_seq = data.build_reverse_sequence(seq) + bidir_seq = data.build_bidirectional_seq(seq, reverse_seq) + label_seq = data.build_labeled_sequence(bidir_seq, True) + + for (i, ts), j in zip( + enumerate(label_seq[:-1]), reversed(range(len(seq) - 1))): + self.assertAllEqual(ts.tokens, [i, j]) + self.assertEqual(ts.label, 0) + self.assertEqual(ts.weight, 0.0) + + final_timestep = label_seq[-1] + eos_id = len(seq) - 1 + self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id]) + self.assertEqual(final_timestep.label, 1) + self.assertEqual(final_timestep.weight, 1.0) + + def testReverseSeq(self): + seq = self._buildDummySequence() + reverse_seq = data.build_reverse_sequence(seq) + for i, ts in enumerate(reversed(reverse_seq[:-1])): + self.assertEqual(ts.token, i) + self.assertEqual(ts.label, 0) + self.assertEqual(ts.weight, 0.0) + + final_timestep = reverse_seq[-1] + eos_id = len(seq) - 1 + self.assertEqual(final_timestep.token, eos_id) + self.assertEqual(final_timestep.label, 0) + self.assertEqual(final_timestep.weight, 0.0) + + def testBidirSeq(self): + seq = self._buildDummySequence() + reverse_seq = data.build_reverse_sequence(seq) + bidir_seq = data.build_bidirectional_seq(seq, reverse_seq) + for (i, ts), j in zip( + enumerate(bidir_seq[:-1]), reversed(range(len(seq) - 1))): + self.assertAllEqual(ts.tokens, [i, j]) + self.assertEqual(ts.label, 0) + self.assertEqual(ts.weight, 0.0) + + final_timestep = bidir_seq[-1] + eos_id = len(seq) - 1 + self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id]) + self.assertEqual(final_timestep.label, 0) + self.assertEqual(final_timestep.weight, 0.0) + + def testLabelGain(self): + seq = self._buildDummySequence() + label_seq = data.build_labeled_sequence(seq, True, label_gain=True) + for i, ts in enumerate(label_seq): + self.assertEqual(ts.token, i) + self.assertEqual(ts.label, 1) + self.assertNear(ts.weight, float(i) / (len(seq) - 1), 1e-3) + + +if __name__ == '__main__': + tf.test.main() diff --git a/adversarial_text/data/document_generators.py b/adversarial_text/data/document_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..990dae775fe4218de7cfa84445dae1f0e3bc1eda --- /dev/null +++ b/adversarial_text/data/document_generators.py @@ -0,0 +1,370 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Input readers and document/token generators for datasets.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import csv +import os +import random + +import tensorflow as tf + +from adversarial_text.data import data_utils + +flags = tf.app.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string('dataset', '', 'Which dataset to generate data for') + +# Preprocessing config +flags.DEFINE_boolean('output_unigrams', True, 'Whether to output unigrams.') +flags.DEFINE_boolean('output_bigrams', False, 'Whether to output bigrams.') +flags.DEFINE_boolean('output_char', False, 'Whether to output characters.') +flags.DEFINE_boolean('lowercase', True, 'Whether to lowercase document terms.') + +# IMDB +flags.DEFINE_string('imdb_input_dir', '', 'The input directory containing the ' + 'IMDB sentiment dataset.') +flags.DEFINE_integer('imdb_validation_pos_start_id', 10621, 'File id of the ' + 'first file in the pos sentiment validation set.') +flags.DEFINE_integer('imdb_validation_neg_start_id', 10625, 'File id of the ' + 'first file in the neg sentiment validation set.') + +# DBpedia +flags.DEFINE_string('dbpedia_input_dir', '', + 'Path to DBpedia directory containing train.csv and ' + 'test.csv.') + +# Reuters Corpus (rcv1) +flags.DEFINE_string('rcv1_input_dir', '', + 'Path to rcv1 directory containing train.csv, unlab.csv, ' + 'and test.csv.') + +# Rotten Tomatoes +flags.DEFINE_string('rt_input_dir', '', + 'The Rotten Tomatoes dataset input directory.') + + +# The amazon reviews input file to use in either the RT or IMDB datasets. +flags.DEFINE_string('amazon_unlabeled_input_file', '', + 'The unlabeled Amazon Reviews dataset input file. If set, ' + 'the input file is used to augment RT and IMDB vocab.') + +Document = namedtuple('Document', + 'content is_validation is_test label add_tokens') + + +def documents(dataset='train', + include_unlabeled=False, + include_validation=False): + """Generates Documents based on FLAGS.dataset. + + Args: + dataset: str, identifies folder within IMDB data directory, test or train. + include_unlabeled: bool, whether to include the unsup directory. Only valid + when dataset=train. + include_validation: bool, whether to include validation data. + + Yields: + Document + + Raises: + ValueError: if include_unlabeled is true but dataset is not 'train' + """ + + if include_unlabeled and dataset != 'train': + raise ValueError('If include_unlabeled=True, must use train dataset') + + # Set the random seed so that we have the same validation set when running + # gen_data and gen_vocab. 
+ random.seed(302) + + ds = FLAGS.dataset + if ds == 'imdb': + docs_gen = imdb_documents + elif ds == 'dbpedia': + docs_gen = dbpedia_documents + elif ds == 'rcv1': + docs_gen = rcv1_documents + elif ds == 'rt': + docs_gen = rt_documents + else: + raise ValueError('Unrecognized dataset %s' % FLAGS.dataset) + + for doc in docs_gen(dataset, include_unlabeled, include_validation): + yield doc + + +def tokens(doc): + """Given a Document, produces character or word tokens. + + Tokens can be either characters, or word-level tokens (unigrams and/or + bigrams). + + Args: + doc: Document to produce tokens from. + + Yields: + token + + Raises: + ValueError: if all FLAGS.{output_unigrams, output_bigrams, output_char} + are False. + """ + if not (FLAGS.output_unigrams or FLAGS.output_bigrams or FLAGS.output_char): + raise ValueError( + 'At least one of {FLAGS.output_unigrams, FLAGS.output_bigrams, ' + 'FLAGS.output_char} must be true') + + content = doc.content.strip() + if FLAGS.lowercase: + content = content.lower() + + if FLAGS.output_char: + for char in content: + yield char + + else: + tokens_ = data_utils.split_by_punct(content) + for i, token in enumerate(tokens_): + if FLAGS.output_unigrams: + yield token + + if FLAGS.output_bigrams: + previous_token = (tokens_[i - 1] if i > 0 else data_utils.EOS_TOKEN) + bigram = '_'.join([previous_token, token]) + yield bigram + if (i + 1) == len(tokens_): + bigram = '_'.join([token, data_utils.EOS_TOKEN]) + yield bigram + + +def imdb_documents(dataset='train', + include_unlabeled=False, + include_validation=False): + """Generates Documents for IMDB dataset. + + Data from http://ai.stanford.edu/~amaas/data/sentiment/ + + Args: + dataset: str, identifies folder within IMDB data directory, test or train. + include_unlabeled: bool, whether to include the unsup directory. Only valid + when dataset=train. + include_validation: bool, whether to include validation data. + + Yields: + Document + + Raises: + ValueError: if FLAGS.imdb_input_dir is empty. + """ + if not FLAGS.imdb_input_dir: + raise ValueError('Must provide FLAGS.imdb_input_dir') + + tf.logging.info('Generating IMDB documents...') + + def check_is_validation(filename, class_label): + if class_label is None: + return False + file_idx = int(filename.split('_')[0]) + is_pos_valid = (class_label and + file_idx >= FLAGS.imdb_validation_pos_start_id) + is_neg_valid = (not class_label and + file_idx >= FLAGS.imdb_validation_neg_start_id) + return is_pos_valid or is_neg_valid + + dirs = [(dataset + '/pos', True), (dataset + '/neg', False)] + if include_unlabeled: + dirs.append(('train/unsup', None)) + + for d, class_label in dirs: + for filename in os.listdir(os.path.join(FLAGS.imdb_input_dir, d)): + is_validation = check_is_validation(filename, class_label) + if is_validation and not include_validation: + continue + + with open(os.path.join(FLAGS.imdb_input_dir, d, filename)) as imdb_f: + content = imdb_f.read() + yield Document( + content=content, + is_validation=is_validation, + is_test=False, + label=class_label, + add_tokens=True) + + if FLAGS.amazon_unlabeled_input_file and include_unlabeled: + with open(FLAGS.amazon_unlabeled_input_file) as rt_f: + for content in rt_f: + yield Document(content=content, is_validation=False, is_test=False, + label=None, add_tokens=False) + + +def dbpedia_documents(dataset='train', + include_unlabeled=False, + include_validation=False): + """Generates Documents for DBpedia dataset. + + Dataset linked to at https://github.com/zhangxiangxiao/Crepe. 
+ + Args: + dataset: str, identifies the csv file within the DBpedia data directory, + test or train. + include_unlabeled: bool, unused. + include_validation: bool, whether to include validation data, which is a + randomly selected 10% of the data. + + Yields: + Document + + Raises: + ValueError: if FLAGS.dbpedia_input_dir is empty. + """ + del include_unlabeled + + if not FLAGS.dbpedia_input_dir: + raise ValueError('Must provide FLAGS.dbpedia_input_dir') + + tf.logging.info('Generating DBpedia documents...') + + with open(os.path.join(FLAGS.dbpedia_input_dir, dataset + '.csv')) as db_f: + reader = csv.reader(db_f) + for row in reader: + # 10% of the data is randomly held out + is_validation = random.randint(1, 10) == 1 + if is_validation and not include_validation: + continue + + content = row[1] + ' ' + row[2] + yield Document( + content=content, + is_validation=is_validation, + is_test=False, + label=int(row[0]), + add_tokens=True) + + +def rcv1_documents(dataset='train', + include_unlabeled=True, + include_validation=False): + # pylint:disable=line-too-long + """Generates Documents for Reuters Corpus (rcv1) dataset. + + Dataset described at http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm + + Args: + dataset: str, identifies the csv file within the rcv1 data directory. + include_unlabeled: bool, whether to include the unlab file. Only valid + when dataset=train. + include_validation: bool, whether to include validation data, which is a + randomly selected 10% of the data. + + Yields: + Document + + Raises: + ValueError: if FLAGS.rcv1_input_dir is empty. + """ + # pylint:enable=line-too-long + + if not FLAGS.rcv1_input_dir: + raise ValueError('Must provide FLAGS.rcv1_input_dir') + + tf.logging.info('Generating rcv1 documents...') + + datasets = [dataset] + if include_unlabeled: + if dataset == 'train': + datasets.append('unlab') + for dset in datasets: + with open(os.path.join(FLAGS.rcv1_input_dir, dset + '.csv')) as db_f: + reader = csv.reader(db_f) + for row in reader: + # 10% of the data is randomly held out + is_validation = random.randint(1, 10) == 1 + if is_validation and not include_validation: + continue + + content = row[1] + yield Document( + content=content, + is_validation=is_validation, + is_test=False, + label=int(row[0]), + add_tokens=True) + + +def rt_documents(dataset='train', + include_unlabeled=True, + include_validation=False): + # pylint:disable=line-too-long + """Generates Documents for the Rotten Tomatoes dataset. + + Dataset available at http://www.cs.cornell.edu/people/pabo/movie-review-data/ + In this dataset, amazon reviews are used for the unlabeled data. + + Args: + dataset: str, identifies the data subdirectory. + include_unlabeled: bool, whether to include the unlabeled data. Only valid + when dataset=train. + include_validation: bool, whether to include validation data, which is a + randomly selected 10% of the data. + + Yields: + Document + + Raises: + ValueError: if FLAGS.rt_input_dir is empty. 
+ """ + # pylint:enable=line-too-long + + if not FLAGS.rt_input_dir: + raise ValueError('Must provide FLAGS.rt_input_dir') + + tf.logging.info('Generating rt documents...') + + data_files = [] + input_filenames = os.listdir(FLAGS.rt_input_dir) + for inp_fname in input_filenames: + if inp_fname.endswith('.pos'): + data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), True)) + elif inp_fname.endswith('.neg'): + data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), False)) + if include_unlabeled and FLAGS.amazon_unlabeled_input_file: + data_files.append((FLAGS.amazon_unlabeled_input_file, None)) + + for filename, class_label in data_files: + with open(filename) as rt_f: + for content in rt_f: + if class_label is None: + # Process Amazon Review data for unlabeled dataset + if content.startswith('review/text'): + yield Document(content=content, is_validation=False, + is_test=False, label=None, add_tokens=False) + else: + # 10% of the data is randomly held out for the validation set and + # another 10% of it is randomly held out for the test set + random_int = random.randint(1, 10) + is_validation = random_int == 1 + is_test = random_int == 2 + if (is_test and dataset != 'test') or ( + is_validation and not include_validation): + continue + + yield Document(content=content, is_validation=is_validation, + is_test=is_test, label=class_label, add_tokens=True) diff --git a/adversarial_text/data/gen_data.py b/adversarial_text/data/gen_data.py new file mode 100644 index 0000000000000000000000000000000000000000..0631de8e77520c9ec00179fa44e5425de7ec2cb2 --- /dev/null +++ b/adversarial_text/data/gen_data.py @@ -0,0 +1,215 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Create TFRecord files of SequenceExample protos from dataset. + +Constructs 3 datasets: + 1. Labeled data for the LSTM classification model, optionally with label gain. + "*_classification.tfrecords" (for both unidirectional and bidirectional + models). + 2. Data for the unsupervised LM-LSTM model that predicts the next token. + "*_lm.tfrecords" (generates forward and reverse data). + 3. Data for the unsupervised SA-LSTM model that uses Seq2Seq. + "*_sa.tfrecords". +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import string + +import tensorflow as tf + +from adversarial_text.data import data_utils +from adversarial_text.data import document_generators + +data = data_utils +flags = tf.app.flags +FLAGS = flags.FLAGS + +# Flags for input data are in document_generators.py +flags.DEFINE_string('vocab_file', '', 'Path to the vocabulary file. Defaults ' + 'to FLAGS.output_dir/vocab.txt.') +flags.DEFINE_string('output_dir', '', 'Path to save tfrecords.') + +# Config +flags.DEFINE_boolean('label_gain', False, + 'Enable linear label gain. 
If True, sentiment label will ' + 'be included at each timestep with linear weight ' + 'increase.') + + +def build_shuffling_tf_record_writer(fname): + return data.ShufflingTFRecordWriter(os.path.join(FLAGS.output_dir, fname)) + + +def build_tf_record_writer(fname): + return tf.python_io.TFRecordWriter(os.path.join(FLAGS.output_dir, fname)) + + +def build_input_sequence(doc, vocab_ids): + """Builds input sequence from file. + + Splits lines on whitespace. Treats punctuation as whitespace. For word-level + sequences, only keeps terms that are in the vocab. + + Terms are added as token in the SequenceExample. The EOS_TOKEN is also + appended. Label and weight features are set to 0. + + Args: + doc: Document (defined in `document_generators`) from which to build the + sequence. + vocab_ids: dict. + + Returns: + SequenceExampleWrapper. + """ + seq = data.SequenceWrapper() + for token in document_generators.tokens(doc): + if token in vocab_ids: + seq.add_timestep().set_token(vocab_ids[token]) + + # Add EOS token to end + seq.add_timestep().set_token(vocab_ids[data.EOS_TOKEN]) + + return seq + + +def make_vocab_ids(vocab_filename): + if FLAGS.output_char: + ret = dict([(char, i) for i, char in enumerate(string.printable)]) + ret[data.EOS_TOKEN] = len(string.printable) + return ret + else: + with open(vocab_filename) as vocab_f: + return dict([(line.strip(), i) for i, line in enumerate(vocab_f)]) + + +def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all): + """Generates training data.""" + + # Construct training data writers + writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM) + writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA) + writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS) + writer_valid_class = build_tf_record_writer(data.VALID_CLASS) + writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM) + writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS) + writer_bd_valid_class = build_shuffling_tf_record_writer(data.VALID_BD_CLASS) + + for doc in document_generators.documents( + dataset='train', include_unlabeled=True, include_validation=True): + input_seq = build_input_sequence(doc, vocab_ids) + if len(input_seq) < 2: + continue + rev_seq = data.build_reverse_sequence(input_seq) + lm_seq = data.build_lm_sequence(input_seq) + rev_lm_seq = data.build_lm_sequence(rev_seq) + seq_ae_seq = data.build_seq_ae_sequence(input_seq) + if doc.label is not None: + # Used for sentiment classification. 
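+      # With --label_gain, the class label is placed on every timestep and
+      # the weight ramps linearly from 0 to 1; otherwise only the final (EOS)
+      # timestep carries the label, with weight 1.0. Label gain is disabled
+      # for validation examples (see build_labeled_sequence in data_utils).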
+ label_seq = data.build_labeled_sequence( + input_seq, + doc.label, + label_gain=(FLAGS.label_gain and not doc.is_validation)) + bd_label_seq = data.build_labeled_sequence( + data.build_bidirectional_seq(input_seq, rev_seq), + doc.label, + label_gain=(FLAGS.label_gain and not doc.is_validation)) + class_writer = writer_valid_class if doc.is_validation else writer_class + bd_class_writer = (writer_bd_valid_class + if doc.is_validation else writer_bd_class) + class_writer.write(label_seq.seq.SerializeToString()) + bd_class_writer.write(bd_label_seq.seq.SerializeToString()) + + # Write + lm_seq_ser = lm_seq.seq.SerializeToString() + seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString() + writer_lm_all.write(lm_seq_ser) + writer_seq_ae_all.write(seq_ae_seq_ser) + if not doc.is_validation: + writer_lm.write(lm_seq_ser) + writer_rev_lm.write(rev_lm_seq.seq.SerializeToString()) + writer_seq_ae.write(seq_ae_seq_ser) + + # Close writers + writer_lm.close() + writer_seq_ae.close() + writer_class.close() + writer_valid_class.close() + writer_rev_lm.close() + writer_bd_class.close() + writer_bd_valid_class.close() + + +def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all): + """Generates test data.""" + # Construct test data writers + writer_lm = build_shuffling_tf_record_writer(data.TEST_LM) + writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM) + writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA) + writer_class = build_tf_record_writer(data.TEST_CLASS) + writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS) + + for doc in document_generators.documents( + dataset='test', include_unlabeled=False, include_validation=True): + input_seq = build_input_sequence(doc, vocab_ids) + if len(input_seq) < 2: + continue + rev_seq = data.build_reverse_sequence(input_seq) + lm_seq = data.build_lm_sequence(input_seq) + rev_lm_seq = data.build_lm_sequence(rev_seq) + seq_ae_seq = data.build_seq_ae_sequence(input_seq) + label_seq = data.build_labeled_sequence(input_seq, doc.label) + bd_label_seq = data.build_labeled_sequence( + data.build_bidirectional_seq(input_seq, rev_seq), doc.label) + + # Write + writer_class.write(label_seq.seq.SerializeToString()) + writer_bd_class.write(bd_label_seq.seq.SerializeToString()) + lm_seq_ser = lm_seq.seq.SerializeToString() + seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString() + writer_lm.write(lm_seq_ser) + writer_rev_lm.write(rev_lm_seq.seq.SerializeToString()) + writer_seq_ae.write(seq_ae_seq_ser) + writer_lm_all.write(lm_seq_ser) + writer_seq_ae_all.write(seq_ae_seq_ser) + + # Close test writers + writer_lm.close() + writer_rev_lm.close() + writer_seq_ae.close() + writer_class.close() + writer_bd_class.close() + + +def main(_): + tf.logging.info('Assigning vocabulary ids...') + vocab_ids = make_vocab_ids( + FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt')) + + with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all: + with build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all: + + tf.logging.info('Generating training data...') + generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all) + + tf.logging.info('Generating test data...') + generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all) + + +if __name__ == '__main__': + tf.app.run() diff --git a/adversarial_text/data/gen_vocab.py b/adversarial_text/data/gen_vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..43a8688fa95fd4fa917894e4a709df2303bdb1e0 --- /dev/null +++ 
b/adversarial_text/data/gen_vocab.py @@ -0,0 +1,98 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generates vocabulary and term frequency files for datasets.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + +import tensorflow as tf + +from adversarial_text.data import data_utils +from adversarial_text.data import document_generators + +flags = tf.app.flags +FLAGS = flags.FLAGS + +# Flags controlling input are in document_generators.py + +flags.DEFINE_string('output_dir', '', + 'Path to save vocab.txt and vocab_freq.txt.') + +flags.DEFINE_boolean('use_unlabeled', True, 'Whether to use the ' + 'unlabeled sentiment dataset in the vocabulary.') +flags.DEFINE_boolean('include_validation', False, 'Whether to include the ' + 'validation set in the vocabulary.') +flags.DEFINE_integer('doc_count_threshold', 1, 'The minimum number of ' + 'documents a word or bigram should occur in to keep ' + 'it in the vocabulary.') + +MAX_VOCAB_SIZE = 100 * 1000 + + +def fill_vocab_from_doc(doc, vocab_freqs, doc_counts): + """Fills vocabulary and doc counts with tokens from doc. + + Args: + doc: Document to read tokens from. + vocab_freqs: dict + doc_counts: dict + + Returns: + None + """ + doc_seen = set() + + for token in document_generators.tokens(doc): + if doc.add_tokens or token in vocab_freqs: + vocab_freqs[token] += 1 + if token not in doc_seen: + doc_counts[token] += 1 + doc_seen.add(token) + + +def main(_): + vocab_freqs = defaultdict(int) + doc_counts = defaultdict(int) + + # Fill vocabulary frequencies map and document counts map + for doc in document_generators.documents( + dataset='train', + include_unlabeled=FLAGS.use_unlabeled, + include_validation=FLAGS.include_validation): + fill_vocab_from_doc(doc, vocab_freqs, doc_counts) + + # Filter out low-occurring terms + vocab_freqs = dict((term, freq) for term, freq in vocab_freqs.iteritems() + if doc_counts[term] > FLAGS.doc_count_threshold) + + # Sort by frequency + ordered_vocab_freqs = data_utils.sort_vocab_by_frequency(vocab_freqs) + + # Limit vocab size + ordered_vocab_freqs = ordered_vocab_freqs[:MAX_VOCAB_SIZE] + + # Add EOS token + ordered_vocab_freqs.append((data_utils.EOS_TOKEN, 1)) + + # Write + tf.gfile.MakeDirs(FLAGS.output_dir) + data_utils.write_vocab_and_frequency(ordered_vocab_freqs, FLAGS.output_dir) + + +if __name__ == '__main__': + tf.app.run() diff --git a/adversarial_text/evaluate.py b/adversarial_text/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..7c68f88cf33703b5830d16ef14e8d206d3de40ac --- /dev/null +++ b/adversarial_text/evaluate.py @@ -0,0 +1,129 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Evaluates text classification model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import time + +import tensorflow as tf + +import graphs + +flags = tf.app.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string('master', '', + 'BNS name prefix of the Tensorflow eval master, ' + 'or "local".') +flags.DEFINE_string('eval_dir', '/tmp/text_eval', + 'Directory where to write event logs.') +flags.DEFINE_string('eval_data', 'test', 'Specify which dataset is used. ' + '("train", "valid", "test") ') + +flags.DEFINE_string('checkpoint_dir', '/tmp/text_train', + 'Directory where to read model checkpoints.') +flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run the eval.') +flags.DEFINE_integer('num_examples', 32, 'Number of examples to run.') +flags.DEFINE_bool('run_once', False, 'Whether to run eval only once.') + + +def restore_from_checkpoint(sess, saver): + """Restore model from checkpoint. + + Args: + sess: Session. + saver: Saver for restoring the checkpoint. + + Returns: + bool: Whether the checkpoint was found and restored + """ + ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) + if not ckpt or not ckpt.model_checkpoint_path: + tf.logging.info('No checkpoint found at %s', FLAGS.checkpoint_dir) + return False + + saver.restore(sess, ckpt.model_checkpoint_path) + return True + + +def run_eval(eval_ops, summary_writer, saver): + """Runs evaluation over FLAGS.num_examples examples. + + Args: + eval_ops: dict + summary_writer: Summary writer. + saver: Saver. + + Returns: + dict, with value being the average over all examples. 
+ """ + sv = tf.train.Supervisor(logdir=FLAGS.eval_dir, saver=None, summary_op=None) + with sv.managed_session( + master=FLAGS.master, start_standard_services=False) as sess: + if not restore_from_checkpoint(sess, saver): + return + sv.start_queue_runners(sess) + + metric_names, ops = zip(*eval_ops.items()) + value_ops, update_ops = zip(*ops) + + # Run update ops + num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size)) + tf.logging.info('Running %d batches for evaluation.', num_batches) + for i in range(num_batches): + if (i + 1) % 10 == 0: + tf.logging.info('Running batch %d/%d...', i + 1, num_batches) + sess.run(update_ops) + + values = sess.run(value_ops) + metric_values = dict(zip(metric_names, values)) + + tf.logging.info('Eval metric values:') + summary = tf.summary.Summary() + for name, val in metric_values.items(): + summary.value.add(tag=name, simple_value=val) + tf.logging.info('%s = %.3f', name, val) + + global_step_val = sess.run(tf.train.get_global_step()) + summary_writer.add_summary(summary, global_step_val) + + return metric_values + + +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + tf.gfile.MakeDirs(FLAGS.eval_dir) + tf.logging.info('Building eval graph...') + output = graphs.get_model().eval_graph(FLAGS.eval_data) + eval_ops, moving_averaged_variables = output + + saver = tf.train.Saver(moving_averaged_variables) + summary_writer = tf.summary.FileWriter( + FLAGS.eval_dir, graph=tf.get_default_graph()) + + while True: + run_eval(eval_ops, summary_writer, saver) + if FLAGS.run_once: + break + time.sleep(FLAGS.eval_interval_secs) + + +if __name__ == '__main__': + tf.app.run() diff --git a/adversarial_text/graphs.py b/adversarial_text/graphs.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5dce8d0e01cb0ed22d4bbcc56ac43ec11840b1 --- /dev/null +++ b/adversarial_text/graphs.py @@ -0,0 +1,661 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Virtual adversarial text models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv +import os +import tensorflow as tf + +import adversarial_losses as adv_lib +import inputs as inputs_lib +import layers as layers_lib + +flags = tf.app.flags +FLAGS = flags.FLAGS + +# Flags governing adversarial training are defined in adversarial_losses.py. + +# Classifier +flags.DEFINE_integer('num_classes', 2, 'Number of classes for classification') + +# Data path +flags.DEFINE_string('data_dir', '/tmp/IMDB', + 'Directory path to preprocessed text dataset.') +flags.DEFINE_string('vocab_freq_path', None, + 'Path to pre-calculated vocab frequency data. 
If '
+                    'None, use FLAGS.data_dir/vocab_freq.txt.')
+flags.DEFINE_integer('batch_size', 64, 'Size of the batch.')
+flags.DEFINE_integer('num_timesteps', 100, 'Number of timesteps for BPTT')
+
+# Model architecture
+flags.DEFINE_bool('bidir_lstm', False, 'Whether to build a bidirectional LSTM.')
+flags.DEFINE_integer('rnn_num_layers', 1, 'Number of LSTM layers.')
+flags.DEFINE_integer('rnn_cell_size', 512,
+                     'Number of hidden units in the LSTM.')
+flags.DEFINE_integer('cl_num_layers', 1,
+                     'Number of hidden layers of classification model.')
+flags.DEFINE_integer('cl_hidden_size', 30,
+                     'Number of hidden units in classification layer.')
+flags.DEFINE_integer('num_candidate_samples', -1,
+                     'Num samples used in the sampled output layer.')
+flags.DEFINE_bool('use_seq2seq_autoencoder', False,
+                  'If True, seq2seq auto-encoder is used to pretrain. '
+                  'If False, standard language model is used.')
+
+# Vocabulary and embeddings
+flags.DEFINE_integer('embedding_dims', 256, 'Dimensions of embedded vector.')
+flags.DEFINE_integer('vocab_size', 86934,
+                     'The size of the vocabulary. This value must exactly '
+                     'match the vocabulary size of the preprocessed dataset, '
+                     'including the end-of-sequence token, which the '
+                     'preprocessing code always assigns to the last '
+                     'vocabulary index.')
+flags.DEFINE_bool('normalize_embeddings', True,
+                  'Normalize word embeddings by vocab frequency')
+
+# Optimization
+flags.DEFINE_float('learning_rate', 0.001, 'Learning rate while fine-tuning.')
+flags.DEFINE_float('learning_rate_decay_factor', 1.0,
+                   'Learning rate decay factor')
+flags.DEFINE_boolean('sync_replicas', False,
+                     'Whether to use SyncReplicasOptimizer.')
+flags.DEFINE_integer('replicas_to_aggregate', 1,
+                     'The number of replicas to aggregate')
+
+# Regularization
+flags.DEFINE_float('max_grad_norm', 1.0,
+                   'Clip the global gradient norm to this value.')
+flags.DEFINE_float('keep_prob_emb', 1.0, 'keep probability on embedding layer')
+flags.DEFINE_float('keep_prob_lstm_out', 1.0,
+                   'keep probability on lstm output.')
+flags.DEFINE_float('keep_prob_cl_hidden', 1.0,
+                   'keep probability on classification hidden layer')
+
+
+def get_model():
+  if FLAGS.bidir_lstm:
+    return VatxtBidirModel()
+  else:
+    return VatxtModel()
+
+
+class VatxtModel(object):
+  """Constructs training and evaluation graphs.
+
+  Main methods: `classifier_training()`, `language_model_training()`,
+  and `eval_graph()`.
+
+  Variable reuse is a critical part of the model, both for sharing variables
+  between the language model and the classifier, and for reusing variables for
+  the adversarial loss calculation. To ensure correct variable reuse, all
+  variables are created in Keras-style layers, wherein stateful layers (i.e.
+  layers with variables) are represented as callable instances of the Layer
+  class. Each time a Layer instance is called, it uses the same variables.
+
+  All Layers are constructed in the __init__ method and reused in the various
+  graph-building functions.
+  """
+
+  def __init__(self, cl_logits_input_dim=None):
+    self.global_step = tf.contrib.framework.get_or_create_global_step()
+    self.vocab_freqs = _get_vocab_freqs()
+
+    # Cache VatxtInput objects
+    self.cl_inputs = None
+    self.lm_inputs = None
+
+    # Cache intermediate Tensors that are reused
+    self.tensors = {}
+
+    # Construct layers which are reused in constructing the LM and
+    # Classification graphs. Instantiating them all once here ensures that
+    # variable reuse works correctly.
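+    # For example, the adversarial losses call `cl_loss_from_embedding` a
+    # second time on perturbed embeddings; because that call goes through the
+    # same `self.layers['lstm']` and `self.layers['cl_logits']` instances, the
+    # perturbed pass shares the classifier's variables rather than creating
+    # new ones.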
+ self.layers = {} + self.layers['embedding'] = layers_lib.Embedding( + FLAGS.vocab_size, FLAGS.embedding_dims, FLAGS.normalize_embeddings, + self.vocab_freqs, FLAGS.keep_prob_emb) + self.layers['lstm'] = layers_lib.LSTM( + FLAGS.rnn_cell_size, FLAGS.rnn_num_layers, FLAGS.keep_prob_lstm_out) + self.layers['lm_loss'] = layers_lib.SoftmaxLoss( + FLAGS.vocab_size, + FLAGS.num_candidate_samples, + self.vocab_freqs, + name='LM_loss') + + cl_logits_input_dim = cl_logits_input_dim or FLAGS.rnn_cell_size + self.layers['cl_logits'] = layers_lib.cl_logits_subgraph( + [FLAGS.cl_hidden_size] * FLAGS.cl_num_layers, cl_logits_input_dim, + FLAGS.num_classes, FLAGS.keep_prob_cl_hidden) + + @property + def pretrained_variables(self): + return (self.layers['embedding'].trainable_weights + + self.layers['lstm'].trainable_weights) + + def classifier_training(self): + loss = self.classifier_graph() + train_op = optimize(loss, self.global_step) + return train_op, loss, self.global_step + + def language_model_training(self): + loss = self.language_model_graph() + train_op = optimize(loss, self.global_step) + return train_op, loss, self.global_step + + def classifier_graph(self): + """Constructs classifier graph from inputs to classifier loss. + + * Caches the VatxtInput object in `self.cl_inputs` + * Caches tensors: `cl_embedded`, `cl_logits`, `cl_loss` + + Returns: + loss: scalar float. + """ + inputs = _inputs('train', pretrain=False) + self.cl_inputs = inputs + embedded = self.layers['embedding'](inputs.tokens) + self.tensors['cl_embedded'] = embedded + + _, next_state, logits, loss = self.cl_loss_from_embedding( + embedded, return_intermediates=True) + tf.summary.scalar('classification_loss', loss) + self.tensors['cl_logits'] = logits + self.tensors['cl_loss'] = loss + + acc = layers_lib.accuracy(logits, inputs.labels, inputs.weights) + tf.summary.scalar('accuracy', acc) + + adv_loss = (self.adversarial_loss() * tf.constant( + FLAGS.adv_reg_coeff, name='adv_reg_coeff')) + tf.summary.scalar('adversarial_loss', adv_loss) + + total_loss = loss + adv_loss + tf.summary.scalar('total_classification_loss', total_loss) + + with tf.control_dependencies([inputs.save_state(next_state)]): + total_loss = tf.identity(total_loss) + + return total_loss + + def language_model_graph(self, compute_loss=True): + """Constructs LM graph from inputs to LM loss. + + * Caches the VatxtInput object in `self.lm_inputs` + * Caches tensors: `lm_embedded` + + Args: + compute_loss: bool, whether to compute and return the loss or stop after + the LSTM computation. + + Returns: + loss: scalar float. + """ + inputs = _inputs('train', pretrain=True) + self.lm_inputs = inputs + return self._lm_loss(inputs, compute_loss=compute_loss) + + def _lm_loss(self, + inputs, + emb_key='lm_embedded', + lstm_layer='lstm', + lm_loss_layer='lm_loss', + loss_name='lm_loss', + compute_loss=True): + embedded = self.layers['embedding'](inputs.tokens) + self.tensors[emb_key] = embedded + lstm_out, next_state = self.layers[lstm_layer](embedded, inputs.state, + inputs.length) + if compute_loss: + loss = self.layers[lm_loss_layer]( + [lstm_out, inputs.labels, inputs.weights]) + with tf.control_dependencies([inputs.save_state(next_state)]): + loss = tf.identity(loss) + tf.summary.scalar(loss_name, loss) + + return loss + + def eval_graph(self, dataset='test'): + """Constructs classifier evaluation graph. + + Args: + dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}. 
+ + Returns: + eval_ops: dict + var_restore_dict: dict mapping variable restoration names to variables. + Trainable variables will be mapped to their moving average names. + """ + inputs = _inputs(dataset, pretrain=False) + embedded = self.layers['embedding'](inputs.tokens) + _, next_state, logits, _ = self.cl_loss_from_embedding( + embedded, inputs=inputs, return_intermediates=True) + + eval_ops = { + 'accuracy': + tf.contrib.metrics.streaming_accuracy( + layers_lib.predictions(logits), inputs.labels, + inputs.weights) + } + + with tf.control_dependencies([inputs.save_state(next_state)]): + acc, acc_update = eval_ops['accuracy'] + acc_update = tf.identity(acc_update) + eval_ops['accuracy'] = (acc, acc_update) + + var_restore_dict = make_restore_average_vars_dict() + return eval_ops, var_restore_dict + + def cl_loss_from_embedding(self, + embedded, + inputs=None, + return_intermediates=False): + """Compute classification loss from embedding. + + Args: + embedded: 3-D float Tensor [batch_size, num_timesteps, embedding_dim] + inputs: VatxtInput, defaults to self.cl_inputs. + return_intermediates: bool, whether to return intermediate tensors or only + the final loss. + + Returns: + If return_intermediates is True: + lstm_out, next_state, logits, loss + Else: + loss + """ + if inputs is None: + inputs = self.cl_inputs + + lstm_out, next_state = self.layers['lstm'](embedded, inputs.state, + inputs.length) + logits = self.layers['cl_logits'](lstm_out) + loss = layers_lib.classification_loss(logits, inputs.labels, inputs.weights) + + if return_intermediates: + return lstm_out, next_state, logits, loss + else: + return loss + + def adversarial_loss(self): + """Compute adversarial loss based on FLAGS.adv_training_method.""" + + def random_perturbation_loss(): + return adv_lib.random_perturbation_loss(self.tensors['cl_embedded'], + self.cl_inputs.length, + self.cl_loss_from_embedding) + + def adversarial_loss(): + return adv_lib.adversarial_loss(self.tensors['cl_embedded'], + self.tensors['cl_loss'], + self.cl_loss_from_embedding) + + def virtual_adversarial_loss(): + """Computes virtual adversarial loss. + + Uses lm_inputs and constructs the language model graph if it hasn't yet + been constructed. + + Also ensures that the LM input states are saved for LSTM state-saving + BPTT. + + Returns: + loss: float scalar. 
+ """ + if self.lm_inputs is None: + self.language_model_graph(compute_loss=False) + + def logits_from_embedding(embedded, return_next_state=False): + _, next_state, logits, _ = self.cl_loss_from_embedding( + embedded, inputs=self.lm_inputs, return_intermediates=True) + if return_next_state: + return next_state, logits + else: + return logits + + next_state, lm_cl_logits = logits_from_embedding( + self.tensors['lm_embedded'], return_next_state=True) + + va_loss = adv_lib.virtual_adversarial_loss( + lm_cl_logits, self.tensors['lm_embedded'], self.lm_inputs, + logits_from_embedding) + + with tf.control_dependencies([self.lm_inputs.save_state(next_state)]): + va_loss = tf.identity(va_loss) + + return va_loss + + def combo_loss(): + return adversarial_loss() + virtual_adversarial_loss() + + adv_training_methods = { + # Random perturbation + 'rp': random_perturbation_loss, + # Adversarial training + 'at': adversarial_loss, + # Virtual adversarial training + 'vat': virtual_adversarial_loss, + # Both at and vat + 'atvat': combo_loss, + '': lambda: tf.constant(0.), + None: lambda: tf.constant(0.), + } + + with tf.name_scope('adversarial_loss'): + return adv_training_methods[FLAGS.adv_training_method]() + + +class VatxtBidirModel(VatxtModel): + """Extension of VatxtModel that supports bidirectional input.""" + + def __init__(self): + super(VatxtBidirModel, + self).__init__(cl_logits_input_dim=FLAGS.rnn_cell_size * 2) + + # Reverse LSTM and LM loss for bidirectional models + self.layers['lstm_reverse'] = layers_lib.LSTM( + FLAGS.rnn_cell_size, + FLAGS.rnn_num_layers, + FLAGS.keep_prob_lstm_out, + name='LSTM_Reverse') + self.layers['lm_loss_reverse'] = layers_lib.SoftmaxLoss( + FLAGS.vocab_size, + FLAGS.num_candidate_samples, + self.vocab_freqs, + name='LM_loss_reverse') + + @property + def pretrained_variables(self): + variables = super(VatxtBidirModel, self).pretrained_variables + variables.extend(self.layers['lstm_reverse'].trainable_weights) + return variables + + def classifier_graph(self): + """Constructs classifier graph from inputs to classifier loss. + + * Caches the VatxtInput objects in `self.cl_inputs` + * Caches tensors: `cl_embedded` (tuple of forward and reverse), `cl_logits`, + `cl_loss` + + Returns: + loss: scalar float. + """ + inputs = _inputs('train', pretrain=False, bidir=True) + self.cl_inputs = inputs + f_inputs, _ = inputs + + # Embed both forward and reverse with a shared embedding + embedded = [self.layers['embedding'](inp.tokens) for inp in inputs] + self.tensors['cl_embedded'] = embedded + + _, next_states, logits, loss = self.cl_loss_from_embedding( + embedded, return_intermediates=True) + tf.summary.scalar('classification_loss', loss) + self.tensors['cl_logits'] = logits + self.tensors['cl_loss'] = loss + + acc = layers_lib.accuracy(logits, f_inputs.labels, f_inputs.weights) + tf.summary.scalar('accuracy', acc) + + adv_loss = (self.adversarial_loss() * tf.constant( + FLAGS.adv_reg_coeff, name='adv_reg_coeff')) + tf.summary.scalar('adversarial_loss', adv_loss) + + total_loss = loss + adv_loss + tf.summary.scalar('total_classification_loss', total_loss) + + saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)] + with tf.control_dependencies(saves): + total_loss = tf.identity(total_loss) + + return total_loss + + def language_model_graph(self, compute_loss=True): + """Constructs forward and reverse LM graphs from inputs to LM losses. 
+ + * Caches the VatxtInput objects in `self.lm_inputs` + * Caches tensors: `lm_embedded`, `lm_embedded_reverse` + + Args: + compute_loss: bool, whether to compute and return the loss or stop after + the LSTM computation. + + Returns: + loss: scalar float, sum of forward and reverse losses. + """ + inputs = _inputs('train', pretrain=True, bidir=True) + self.lm_inputs = inputs + f_inputs, r_inputs = inputs + f_loss = self._lm_loss(f_inputs, compute_loss=compute_loss) + r_loss = self._lm_loss( + r_inputs, + emb_key='lm_embedded_reverse', + lstm_layer='lstm_reverse', + lm_loss_layer='lm_loss_reverse', + loss_name='lm_loss_reverse', + compute_loss=compute_loss) + if compute_loss: + return f_loss + r_loss + + def eval_graph(self, dataset='test'): + """Constructs classifier evaluation graph. + + Args: + dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}. + + Returns: + eval_ops: dict + var_restore_dict: dict mapping variable restoration names to variables. + Trainable variables will be mapped to their moving average names. + """ + inputs = _inputs(dataset, pretrain=False, bidir=True) + embedded = [self.layers['embedding'](inp.tokens) for inp in inputs] + _, next_states, logits, _ = self.cl_loss_from_embedding( + embedded, inputs=inputs, return_intermediates=True) + f_inputs, _ = inputs + + eval_ops = { + 'accuracy': + tf.contrib.metrics.streaming_accuracy( + layers_lib.predictions(logits), f_inputs.labels, + f_inputs.weights) + } + + # Save states on accuracy update + saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)] + with tf.control_dependencies(saves): + acc, acc_update = eval_ops['accuracy'] + acc_update = tf.identity(acc_update) + eval_ops['accuracy'] = (acc, acc_update) + + var_restore_dict = make_restore_average_vars_dict() + return eval_ops, var_restore_dict + + def cl_loss_from_embedding(self, + embedded, + inputs=None, + return_intermediates=False): + """Compute classification loss from embedding. + + Args: + embedded: Length 2 tuple of 3-D float Tensor + [batch_size, num_timesteps, embedding_dim]. + inputs: Length 2 tuple of VatxtInput, defaults to self.cl_inputs. + return_intermediates: bool, whether to return intermediate tensors or only + the final loss. + + Returns: + If return_intermediates is True: + lstm_out, next_states, logits, loss + Else: + loss + """ + if inputs is None: + inputs = self.cl_inputs + + out = [] + for (layer_name, emb, inp) in zip(['lstm', 'lstm_reverse'], embedded, + inputs): + out.append(self.layers[layer_name](emb, inp.state, inp.length)) + lstm_outs, next_states = zip(*out) + + # Concatenate output of forward and reverse LSTMs + lstm_out = tf.concat(lstm_outs, 1) + + logits = self.layers['cl_logits'](lstm_out) + f_inputs, _ = inputs # pylint: disable=unpacking-non-sequence + loss = layers_lib.classification_loss(logits, f_inputs.labels, + f_inputs.weights) + + if return_intermediates: + return lstm_out, next_states, logits, loss + else: + return loss + + def adversarial_loss(self): + """Compute adversarial loss based on FLAGS.adv_training_method.""" + + def random_perturbation_loss(): + return adv_lib.random_perturbation_loss_bidir(self.tensors['cl_embedded'], + self.cl_inputs[0].length, + self.cl_loss_from_embedding) + + def adversarial_loss(): + return adv_lib.adversarial_loss_bidir(self.tensors['cl_embedded'], + self.tensors['cl_loss'], + self.cl_loss_from_embedding) + + def virtual_adversarial_loss(): + """Computes virtual adversarial loss. 
+ + Uses lm_inputs and constructs the language model graph if it hasn't yet + been constructed. + + Also ensures that the LM input states are saved for LSTM state-saving + BPTT. + + Returns: + loss: float scalar. + """ + if self.lm_inputs is None: + self.language_model_graph(compute_loss=False) + + def logits_from_embedding(embedded, return_next_state=False): + _, next_states, logits, _ = self.cl_loss_from_embedding( + embedded, inputs=self.lm_inputs, return_intermediates=True) + if return_next_state: + return next_states, logits + else: + return logits + + lm_embedded = (self.tensors['lm_embedded'], + self.tensors['lm_embedded_reverse']) + next_states, lm_cl_logits = logits_from_embedding( + lm_embedded, return_next_state=True) + + va_loss = adv_lib.virtual_adversarial_loss_bidir( + lm_cl_logits, lm_embedded, self.lm_inputs, logits_from_embedding) + + saves = [ + inp.save_state(state) + for (inp, state) in zip(self.lm_inputs, next_states) + ] + with tf.control_dependencies(saves): + va_loss = tf.identity(va_loss) + + return va_loss + + def combo_loss(): + return adversarial_loss() + virtual_adversarial_loss() + + adv_training_methods = { + # Random perturbation + 'rp': random_perturbation_loss, + # Adversarial training + 'at': adversarial_loss, + # Virtual adversarial training + 'vat': virtual_adversarial_loss, + # Both at and vat + 'atvat': combo_loss, + '': lambda: tf.constant(0.), + None: lambda: tf.constant(0.), + } + + with tf.name_scope('adversarial_loss'): + return adv_training_methods[FLAGS.adv_training_method]() + + +def _inputs(dataset='train', pretrain=False, bidir=False): + return inputs_lib.inputs( + data_dir=FLAGS.data_dir, + phase=dataset, + bidir=bidir, + pretrain=pretrain, + use_seq2seq=pretrain and FLAGS.use_seq2seq_autoencoder, + state_size=FLAGS.rnn_cell_size, + num_layers=FLAGS.rnn_num_layers, + batch_size=FLAGS.batch_size, + unroll_steps=FLAGS.num_timesteps) + + +def _get_vocab_freqs(): + """Returns vocab frequencies. + + Returns: + List of integers, length=FLAGS.vocab_size. + + Raises: + ValueError: if the length of the frequency file is not equal to the vocab + size, or if the file is not found. + """ + path = FLAGS.vocab_freq_path or os.path.join(FLAGS.data_dir, 'vocab_freq.txt') + + if tf.gfile.Exists(path): + with tf.gfile.Open(path) as f: + # Get pre-calculated frequencies of words. 
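+      # Line i of the frequency file holds the count for vocabulary index i;
+      # reading with QUOTE_NONE and taking row[-1] tolerates both bare counts
+      # and lines that also carry the word itself.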
+ reader = csv.reader(f, quoting=csv.QUOTE_NONE) + freqs = [int(row[-1]) for row in reader] + if len(freqs) != FLAGS.vocab_size: + raise ValueError('Frequency file length %d != vocab size %d' % + (len(freqs), FLAGS.vocab_size)) + else: + if FLAGS.vocab_freq_path: + raise ValueError('vocab_freq_path not found') + freqs = [1] * FLAGS.vocab_size + + return freqs + + +def make_restore_average_vars_dict(): + """Returns dict mapping moving average names to variables.""" + var_restore_dict = {} + variable_averages = tf.train.ExponentialMovingAverage(0.999) + for v in tf.global_variables(): + if v in tf.trainable_variables(): + name = variable_averages.average_name(v) + else: + name = v.op.name + var_restore_dict[name] = v + return var_restore_dict + + +def optimize(loss, global_step): + return layers_lib.optimize( + loss, global_step, FLAGS.max_grad_norm, FLAGS.learning_rate, + FLAGS.learning_rate_decay_factor, FLAGS.sync_replicas, + FLAGS.replicas_to_aggregate, FLAGS.task) diff --git a/adversarial_text/graphs_test.py b/adversarial_text/graphs_test.py new file mode 100644 index 0000000000000000000000000000000000000000..849e3d06f9f5d51eb4c6b81fe568a16a90fd962c --- /dev/null +++ b/adversarial_text/graphs_test.py @@ -0,0 +1,224 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for graphs.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict +import operator +import os +import random +import shutil +import string +import tempfile + +import tensorflow as tf + +import graphs +from adversarial_text.data import data_utils + +flags = tf.app.flags +FLAGS = flags.FLAGS +data = data_utils + +flags.DEFINE_integer('task', 0, 'Task id; needed for SyncReplicas test') + + +def _build_random_vocabulary(vocab_size=100): + """Builds and returns a dict.""" + vocab = set() + while len(vocab) < (vocab_size - 1): + rand_word = ''.join( + random.choice(string.ascii_lowercase) + for _ in range(random.randint(1, 10))) + vocab.add(rand_word) + + vocab_ids = dict([(word, i) for i, word in enumerate(vocab)]) + vocab_ids[data.EOS_TOKEN] = vocab_size - 1 + return vocab_ids + + +def _build_random_sequence(vocab_ids): + seq_len = random.randint(10, 200) + ids = vocab_ids.values() + seq = data.SequenceWrapper() + for token_id in [random.choice(ids) for _ in range(seq_len)]: + seq.add_timestep().set_token(token_id) + return seq + + +def _build_vocab_frequencies(seqs, vocab_ids): + vocab_freqs = defaultdict(int) + ids_to_words = dict([(i, word) for word, i in vocab_ids.iteritems()]) + for seq in seqs: + for timestep in seq: + vocab_freqs[ids_to_words[timestep.token]] += 1 + + vocab_freqs[data.EOS_TOKEN] = 0 + return vocab_freqs + + +class GraphsTest(tf.test.TestCase): + """Test graph construction methods.""" + + @classmethod + def setUpClass(cls): + # Make model small + FLAGS.batch_size = 2 + FLAGS.num_timesteps = 3 + FLAGS.embedding_dims = 4 + FLAGS.rnn_num_layers = 2 + FLAGS.rnn_cell_size = 4 + FLAGS.cl_num_layers = 2 + FLAGS.cl_hidden_size = 4 + FLAGS.vocab_size = 10 + + # Set input/output flags + FLAGS.data_dir = tempfile.mkdtemp() + + # Build and write sequence files. 
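+    # One small TFRecord file is written per input type the graphs can read:
+    # labeled classification data, forward/reverse LM data, seq2seq
+    # autoencoder data, and the bidirectional variants.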
+ vocab_ids = _build_random_vocabulary(FLAGS.vocab_size) + seqs = [_build_random_sequence(vocab_ids) for _ in range(5)] + seqs_label = [ + data.build_labeled_sequence(seq, random.choice([True, False])) + for seq in seqs + ] + seqs_lm = [data.build_lm_sequence(seq) for seq in seqs] + seqs_ae = [data.build_seq_ae_sequence(seq) for seq in seqs] + seqs_rev = [data.build_reverse_sequence(seq) for seq in seqs] + seqs_bidir = [ + data.build_bidirectional_seq(seq, rev) + for seq, rev in zip(seqs, seqs_rev) + ] + seqs_bidir_label = [ + data.build_labeled_sequence(bd_seq, random.choice([True, False])) + for bd_seq in seqs_bidir + ] + + filenames = [ + data.TRAIN_CLASS, data.TRAIN_LM, data.TRAIN_SA, data.TEST_CLASS, + data.TRAIN_REV_LM, data.TRAIN_BD_CLASS, data.TEST_BD_CLASS + ] + seq_lists = [ + seqs_label, seqs_lm, seqs_ae, seqs_label, seqs_rev, seqs_bidir, + seqs_bidir_label + ] + for fname, seq_list in zip(filenames, seq_lists): + with tf.python_io.TFRecordWriter( + os.path.join(FLAGS.data_dir, fname)) as writer: + for seq in seq_list: + writer.write(seq.seq.SerializeToString()) + + # Write vocab.txt and vocab_freq.txt + vocab_freqs = _build_vocab_frequencies(seqs, vocab_ids) + ordered_vocab_freqs = sorted( + vocab_freqs.items(), key=operator.itemgetter(1), reverse=True) + with open(os.path.join(FLAGS.data_dir, 'vocab.txt'), 'w') as vocab_f: + with open(os.path.join(FLAGS.data_dir, 'vocab_freq.txt'), 'w') as freq_f: + for word, freq in ordered_vocab_freqs: + vocab_f.write('{}\n'.format(word)) + freq_f.write('{}\n'.format(freq)) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(FLAGS.data_dir) + + def setUp(self): + # Reset FLAGS + FLAGS.rnn_num_layers = 1 + FLAGS.sync_replicas = False + FLAGS.adv_training_method = None + FLAGS.num_candidate_samples = -1 + FLAGS.num_classes = 2 + FLAGS.use_seq2seq_autoencoder = False + + # Reset Graph + tf.reset_default_graph() + + def testClassifierGraph(self): + FLAGS.rnn_num_layers = 2 + model = graphs.VatxtModel() + train_op, _, _ = model.classifier_training() + # Pretrained vars: embedding + LSTM layers + self.assertEqual( + len(model.pretrained_variables), 1 + 2 * FLAGS.rnn_num_layers) + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + tf.train.start_queue_runners(sess) + sess.run(train_op) + + def testLanguageModelGraph(self): + train_op, _, _ = graphs.VatxtModel().language_model_training() + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + tf.train.start_queue_runners(sess) + sess.run(train_op) + + def testMulticlass(self): + FLAGS.num_classes = 10 + graphs.VatxtModel().classifier_graph() + + def testATMethods(self): + at_methods = [None, 'rp', 'at', 'vat', 'atvat'] + for method in at_methods: + FLAGS.adv_training_method = method + with tf.Graph().as_default(): + graphs.VatxtModel().classifier_graph() + + # Ensure variables have been reused + # Embedding + LSTM layers + hidden layers + logits layer + expected_num_vars = 1 + 2 * FLAGS.rnn_num_layers + 2 * ( + FLAGS.cl_num_layers) + 2 + self.assertEqual(len(tf.trainable_variables()), expected_num_vars) + + def testSyncReplicas(self): + FLAGS.sync_replicas = True + graphs.VatxtModel().language_model_training() + + def testCandidateSampling(self): + FLAGS.num_candidate_samples = 10 + graphs.VatxtModel().language_model_training() + + def testSeqAE(self): + FLAGS.use_seq2seq_autoencoder = True + graphs.VatxtModel().language_model_training() + + def testBidirLM(self): + graphs.VatxtBidirModel().language_model_graph() + + def 
testBidirClassifier(self): + at_methods = [None, 'rp', 'at', 'vat', 'atvat'] + for method in at_methods: + FLAGS.adv_training_method = method + with tf.Graph().as_default(): + graphs.VatxtBidirModel().classifier_graph() + + # Ensure variables have been reused + # Embedding + 2 LSTM layers + hidden layers + logits layer + expected_num_vars = 1 + 2 * 2 * FLAGS.rnn_num_layers + 2 * ( + FLAGS.cl_num_layers) + 2 + self.assertEqual(len(tf.trainable_variables()), expected_num_vars) + + def testEvalGraph(self): + _, _ = graphs.VatxtModel().eval_graph() + + def testBidirEvalGraph(self): + _, _ = graphs.VatxtBidirModel().eval_graph() + + +if __name__ == '__main__': + tf.test.main() diff --git a/adversarial_text/inputs.py b/adversarial_text/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..ec99eded05ed4abf4dfbf8e52d36c0ef796026ba --- /dev/null +++ b/adversarial_text/inputs.py @@ -0,0 +1,325 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Input utils for virtual adversarial text classification.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tensorflow as tf + +from adversarial_text.data import data_utils + + +class VatxtInput(object): + """Wrapper around NextQueuedSequenceBatch.""" + + def __init__(self, batch, state_name=None, tokens=None, num_states=0): + """Construct VatxtInput. + + Args: + batch: NextQueuedSequenceBatch. + state_name: str, name of state to fetch and save. + tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence. + num_states: int The number of states to store. + """ + self._batch = batch + self._state_name = state_name + self._tokens = (tokens if tokens is not None else + batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID]) + self._num_states = num_states + + # Once the tokens have passed through embedding and LSTM, the output Tensor + # shapes will be time-major, i.e. shape = (time, batch, dim). Here we make + # both weights and labels time-major with a transpose, and then merge the + # time and batch dimensions such that they are both vectors of shape + # (time*batch). 
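+    # For example, with batch_size=2 and num_timesteps=3, weights and labels
+    # each become vectors of length 6 ordered [t0b0, t0b1, t1b0, t1b1, ...],
+    # matching the flattened time-major LSTM output.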
+ w = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT] + w = tf.transpose(w, [1, 0]) + w = tf.reshape(w, [-1]) + self._weights = w + + l = batch.sequences[data_utils.SequenceWrapper.F_LABEL] + l = tf.transpose(l, [1, 0]) + l = tf.reshape(l, [-1]) + self._labels = l + + @property + def tokens(self): + return self._tokens + + @property + def weights(self): + return self._weights + + @property + def labels(self): + return self._labels + + @property + def length(self): + return self._batch.length + + @property + def state_name(self): + return self._state_name + + @property + def state(self): + # LSTM tuple states + state_names = _get_tuple_state_names(self._num_states, self._state_name) + return tuple([ + tf.contrib.rnn.LSTMStateTuple( + self._batch.state(c_name), self._batch.state(h_name)) + for c_name, h_name in state_names + ]) + + def save_state(self, value): + # LSTM tuple states + state_names = _get_tuple_state_names(self._num_states, self._state_name) + save_ops = [] + for (c_state, h_state), (c_name, h_name) in zip(value, state_names): + save_ops.append(self._batch.save_state(c_name, c_state)) + save_ops.append(self._batch.save_state(h_name, h_state)) + return tf.group(*save_ops) + + +def _get_tuple_state_names(num_states, base_name): + """Returns state names for use with LSTM tuple state.""" + state_names = [('{}_{}_c'.format(i, base_name), '{}_{}_h'.format( + i, base_name)) for i in range(num_states)] + return state_names + + +def _split_bidir_tokens(batch): + tokens = batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID] + # Tokens have shape [batch, time, 2] + # forward and reverse have shape [batch, time]. + forward, reverse = [ + tf.squeeze(t, axis=[2]) for t in tf.split(tokens, 2, axis=2) + ] + return forward, reverse + + +def _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq): + """Returns input filenames for configuration. + + Args: + phase: str, 'train', 'test', or 'valid'. + bidir: bool, bidirectional model. + pretrain: bool, pretraining or classification. + use_seq2seq: bool, seq2seq data, only valid if pretrain=True. + + Returns: + Tuple of filenames. + + Raises: + ValueError: if an invalid combination of arguments is provided that does not + map to any data files (e.g. pretrain=False, use_seq2seq=True). 
+ """ + data_spec = (phase, bidir, pretrain, use_seq2seq) + data_specs = { + ('train', True, True, False): (data_utils.TRAIN_LM, + data_utils.TRAIN_REV_LM), + ('train', True, False, False): (data_utils.TRAIN_BD_CLASS,), + ('train', False, True, False): (data_utils.TRAIN_LM,), + ('train', False, True, True): (data_utils.TRAIN_SA,), + ('train', False, False, False): (data_utils.TRAIN_CLASS,), + ('test', True, True, False): (data_utils.TEST_LM, + data_utils.TRAIN_REV_LM), + ('test', True, False, False): (data_utils.TEST_BD_CLASS,), + ('test', False, True, False): (data_utils.TEST_LM,), + ('test', False, True, True): (data_utils.TEST_SA,), + ('test', False, False, False): (data_utils.TEST_CLASS,), + ('valid', True, False, False): (data_utils.VALID_BD_CLASS,), + ('valid', False, False, False): (data_utils.VALID_CLASS,), + } + if data_spec not in data_specs: + raise ValueError( + 'Data specification (phase, bidir, pretrain, use_seq2seq) %s not ' + 'supported' % str(data_spec)) + + return data_specs[data_spec] + + +def _read_single_sequence_example(file_list, tokens_shape=None): + """Reads and parses SequenceExamples from TFRecord-encoded file_list.""" + tf.logging.info('Constructing TFRecordReader from files: %s', file_list) + file_queue = tf.train.string_input_producer(file_list) + reader = tf.TFRecordReader() + seq_key, serialized_record = reader.read(file_queue) + ctx, sequence = tf.parse_single_sequence_example( + serialized_record, + sequence_features={ + data_utils.SequenceWrapper.F_TOKEN_ID: + tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64), + data_utils.SequenceWrapper.F_LABEL: + tf.FixedLenSequenceFeature([], dtype=tf.int64), + data_utils.SequenceWrapper.F_WEIGHT: + tf.FixedLenSequenceFeature([], dtype=tf.float32), + }) + return seq_key, ctx, sequence + + +def _read_and_batch(data_dir, + fname, + state_name, + state_size, + num_layers, + unroll_steps, + batch_size, + bidir_input=False): + """Inputs for text model. + + Args: + data_dir: str, directory containing TFRecord files of SequenceExample. + fname: str, input file name. + state_name: string, key for saved state of LSTM. + state_size: int, size of LSTM state. + num_layers: int, the number of layers in the LSTM. + unroll_steps: int, number of timesteps to unroll for TBTT. + batch_size: int, batch size. + bidir_input: bool, whether the input is bidirectional. If True, creates 2 + states, state_name and state_name + '_reverse'. + + Returns: + Instance of NextQueuedSequenceBatch + + Raises: + ValueError: if file for input specification is not found. + """ + data_path = os.path.join(data_dir, fname) + if not tf.gfile.Exists(data_path): + raise ValueError('Failed to find file: %s' % data_path) + + tokens_shape = [2] if bidir_input else [] + seq_key, ctx, sequence = _read_single_sequence_example( + [data_path], tokens_shape=tokens_shape) + # Set up stateful queue reader. 
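+  # batch_sequences_with_states splits each example into segments of
+  # unroll_steps timesteps and carries the named LSTM c/h states across
+  # segments, which is what enables truncated BPTT with saved state.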
+ state_names = _get_tuple_state_names(num_layers, state_name) + initial_states = {} + for c_state, h_state in state_names: + initial_states[c_state] = tf.zeros(state_size) + initial_states[h_state] = tf.zeros(state_size) + if bidir_input: + rev_state_names = _get_tuple_state_names(num_layers, + '{}_reverse'.format(state_name)) + for rev_c_state, rev_h_state in rev_state_names: + initial_states[rev_c_state] = tf.zeros(state_size) + initial_states[rev_h_state] = tf.zeros(state_size) + batch = tf.contrib.training.batch_sequences_with_states( + input_key=seq_key, + input_sequences=sequence, + input_context=ctx, + input_length=tf.shape(sequence['token_id'])[0], + initial_states=initial_states, + num_unroll=unroll_steps, + batch_size=batch_size, + allow_small_batch=False, + num_threads=4, + capacity=batch_size * 10, + make_keys_unique=True, + make_keys_unique_seed=29392) + return batch + + +def inputs(data_dir=None, + phase='train', + bidir=False, + pretrain=False, + use_seq2seq=False, + state_name='lstm', + state_size=None, + num_layers=0, + batch_size=32, + unroll_steps=100): + """Inputs for text model. + + Args: + data_dir: str, directory containing TFRecord files of SequenceExample. + phase: str, dataset for evaluation {'train', 'valid', 'test'}. + bidir: bool, bidirectional LSTM. + pretrain: bool, whether to read pretraining data or classification data. + use_seq2seq: bool, whether to read seq2seq data or the language model data. + state_name: string, key for saved state of LSTM. + state_size: int, size of LSTM state. + num_layers: int, the number of LSTM layers. + batch_size: int, batch size. + unroll_steps: int, number of timesteps to unroll for TBTT. + + Returns: + Instance of VatxtInput (x2 if bidir=True and pretrain=True, i.e. forward and + reverse). + """ + with tf.name_scope('inputs'): + filenames = _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq) + + if bidir and pretrain: + # Bidirectional pretraining + # Requires separate forward and reverse language model data. 
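+      # Each direction keeps its own LSTM state under a distinct state name
+      # ('lstm' vs. 'lstm_reverse'), so the forward and reverse streams run
+      # truncated BPTT independently.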
+ forward_fname, reverse_fname = filenames + forward_batch = _read_and_batch(data_dir, forward_fname, state_name, + state_size, num_layers, unroll_steps, + batch_size) + state_name_rev = state_name + '_reverse' + reverse_batch = _read_and_batch(data_dir, reverse_fname, state_name_rev, + state_size, num_layers, unroll_steps, + batch_size) + forward_input = VatxtInput( + forward_batch, state_name=state_name, num_states=num_layers) + reverse_input = VatxtInput( + reverse_batch, state_name=state_name_rev, num_states=num_layers) + return forward_input, reverse_input + + elif bidir: + # Classifier bidirectional LSTM + # Shared data source, but separate token/state streams + fname, = filenames + batch = _read_and_batch( + data_dir, + fname, + state_name, + state_size, + num_layers, + unroll_steps, + batch_size, + bidir_input=True) + forward_tokens, reverse_tokens = _split_bidir_tokens(batch) + forward_input = VatxtInput( + batch, + state_name=state_name, + tokens=forward_tokens, + num_states=num_layers) + reverse_input = VatxtInput( + batch, + state_name=state_name + '_reverse', + tokens=reverse_tokens, + num_states=num_layers) + return forward_input, reverse_input + else: + # Unidirectional LM or classifier + fname, = filenames + batch = _read_and_batch( + data_dir, + fname, + state_name, + state_size, + num_layers, + unroll_steps, + batch_size, + bidir_input=False) + return VatxtInput(batch, state_name=state_name, num_states=num_layers) diff --git a/adversarial_text/layers.py b/adversarial_text/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..719928ea2bad2b987b33b50414790b92c6fc0043 --- /dev/null +++ b/adversarial_text/layers.py @@ -0,0 +1,388 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+"""Layers for VatxtModel."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+K = tf.contrib.keras
+
+
+def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):
+  """Construct multiple ReLU layers with dropout and a linear layer."""
+  subgraph = K.models.Sequential(name='cl_logits')
+  for i, layer_size in enumerate(layer_sizes):
+    if i == 0:
+      subgraph.add(
+          K.layers.Dense(layer_size, activation='relu', input_dim=input_size))
+    else:
+      subgraph.add(K.layers.Dense(layer_size, activation='relu'))
+
+    if keep_prob < 1.:
+      # Keras Dropout expects the fraction of units to drop, so convert from
+      # the keep probability used by the flags.
+      subgraph.add(K.layers.Dropout(1. - keep_prob))
+  subgraph.add(K.layers.Dense(1 if num_classes == 2 else num_classes))
+  return subgraph
+
+
+class Embedding(K.layers.Layer):
+  """Embedding layer with frequency-based normalization and dropout."""
+
+  def __init__(self,
+               vocab_size,
+               embedding_dim,
+               normalize=False,
+               vocab_freqs=None,
+               keep_prob=1.,
+               **kwargs):
+    self.vocab_size = vocab_size
+    self.embedding_dim = embedding_dim
+    self.normalized = normalize
+    self.keep_prob = keep_prob
+
+    if normalize:
+      assert vocab_freqs is not None
+      self.vocab_freqs = tf.constant(
+          vocab_freqs, dtype=tf.float32, shape=(vocab_size, 1))
+
+    super(Embedding, self).__init__(**kwargs)
+
+  def build(self, input_shape):
+    with tf.device('/cpu:0'):
+      self.var = self.add_weight(
+          shape=(self.vocab_size, self.embedding_dim),
+          initializer=tf.random_uniform_initializer(-1., 1.),
+          name='embedding')
+
+    if self.normalized:
+      self.var = self._normalize(self.var)
+
+    super(Embedding, self).build(input_shape)
+
+  def call(self, x):
+    embedded = tf.nn.embedding_lookup(self.var, x)
+    if self.keep_prob < 1.:
+      embedded = tf.nn.dropout(embedded, self.keep_prob)
+    return embedded
+
+  def _normalize(self, emb):
+    weights = self.vocab_freqs / tf.reduce_sum(self.vocab_freqs)
+
+    emb -= tf.reduce_sum(weights * emb, 0, keep_dims=True)
+    emb /= tf.sqrt(1e-6 + tf.reduce_sum(
+        weights * tf.pow(emb, 2.), 0, keep_dims=True))
+    return emb
+
+
+class LSTM(object):
+  """LSTM layer using static_rnn.
+
+  Exposes variables in `trainable_weights` property.
+ """ + + def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'): + self.cell_size = cell_size + self.num_layers = num_layers + self.keep_prob = keep_prob + self.reuse = None + self.trainable_weights = None + self.name = name + + def __call__(self, x, initial_state, seq_length): + with tf.variable_scope(self.name, reuse=self.reuse) as vs: + cell = tf.contrib.rnn.MultiRNNCell([ + tf.contrib.rnn.BasicLSTMCell( + self.cell_size, + forget_bias=0.0, + reuse=tf.get_variable_scope().reuse) + for _ in xrange(self.num_layers) + ]) + + # shape(x) = (batch_size, num_timesteps, embedding_dim) + # Convert into a time-major list for static_rnn + x = tf.unstack(tf.transpose(x, perm=[1, 0, 2])) + + lstm_out, next_state = tf.contrib.rnn.static_rnn( + cell, x, initial_state=initial_state, sequence_length=seq_length) + + # Merge time and batch dimensions + # shape(lstm_out) = timesteps * (batch_size, cell_size) + lstm_out = tf.concat(lstm_out, 0) + # shape(lstm_out) = (timesteps*batch_size, cell_size) + + if self.keep_prob < 1.: + lstm_out = tf.nn.dropout(lstm_out, self.keep_prob) + + if self.reuse is None: + self.trainable_weights = vs.global_variables() + + self.reuse = True + + return lstm_out, next_state + + +class SoftmaxLoss(K.layers.Layer): + """Softmax xentropy loss with candidate sampling.""" + + def __init__(self, + vocab_size, + num_candidate_samples=-1, + vocab_freqs=None, + **kwargs): + self.vocab_size = vocab_size + self.num_candidate_samples = num_candidate_samples + self.vocab_freqs = vocab_freqs + super(SoftmaxLoss, self).__init__(**kwargs) + + def build(self, input_shape): + input_shape = input_shape[0] + with tf.device('/cpu:0'): + self.lin_w = self.add_weight( + shape=(input_shape[-1], self.vocab_size), + name='lm_lin_w', + initializer='glorot_uniform') + self.lin_b = self.add_weight( + shape=(self.vocab_size,), + name='lm_lin_b', + initializer='glorot_uniform') + + super(SoftmaxLoss, self).build(input_shape) + + def call(self, inputs): + x, labels, weights = inputs + if self.num_candidate_samples > -1: + assert self.vocab_freqs is not None + labels = tf.expand_dims(labels, -1) + sampled = tf.nn.fixed_unigram_candidate_sampler( + true_classes=labels, + num_true=1, + num_sampled=self.num_candidate_samples, + unique=True, + range_max=self.vocab_size, + unigrams=self.vocab_freqs) + + lm_loss = tf.nn.sampled_softmax_loss( + weights=tf.transpose(self.lin_w), + biases=self.lin_b, + labels=labels, + inputs=x, + num_sampled=self.num_candidate_samples, + num_classes=self.vocab_size, + sampled_values=sampled) + else: + logits = tf.matmul(x, self.lin_w) + self.lin_b + lm_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + + lm_loss = tf.identity( + tf.reduce_sum(lm_loss * weights) / _num_labels(weights), + name='lm_xentropy_loss') + return lm_loss + + +def classification_loss(logits, labels, weights): + """Computes cross entropy loss between logits and labels. + + Args: + logits: 2-D [timesteps*batch_size, m] float tensor, where m=1 if + num_classes=2, otherwise m=num_classes. + labels: 1-D [timesteps*batch_size] integer tensor. + weights: 2-D [timesteps*batch_size] float tensor. + + Returns: + Loss scalar of type float. 
+ """ + inner_dim = logits.get_shape().as_list()[-1] + with tf.name_scope('classifier_loss'): + # Logistic loss + if inner_dim == 1: + loss = tf.nn.sigmoid_cross_entropy_with_logits( + logits=tf.squeeze(logits), labels=tf.cast(labels, tf.float32)) + # Softmax loss + else: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + + num_lab = _num_labels(weights) + tf.summary.scalar('num_labels', num_lab) + return tf.identity( + tf.reduce_sum(weights * loss) / num_lab, name='classification_xentropy') + + +def accuracy(logits, targets, weights): + """Computes prediction accuracy. + + Args: + logits: 2-D classifier logits [timesteps*batch_size, num_classes] + targets: 1-D [timesteps*batch_size] integer tensor. + weights: 1-D [timesteps*batch_size] float tensor. + + Returns: + Accuracy: float scalar. + """ + with tf.name_scope('accuracy'): + eq = tf.cast(tf.equal(predictions(logits), targets), tf.float32) + return tf.identity( + tf.reduce_sum(weights * eq) / _num_labels(weights), name='accuracy') + + +def predictions(logits): + """Class prediction from logits.""" + inner_dim = logits.get_shape().as_list()[-1] + with tf.name_scope('predictions'): + # For binary classification + if inner_dim == 1: + pred = tf.cast(tf.greater(tf.squeeze(logits), 0.5), tf.int64) + # For multi-class classification + else: + pred = tf.argmax(logits, 1) + return pred + + +def _num_labels(weights): + """Number of 1's in weights. Returns 1. if 0.""" + num_labels = tf.reduce_sum(weights) + num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels) + return num_labels + + +def optimize(loss, + global_step, + max_grad_norm, + lr, + lr_decay, + sync_replicas=False, + replicas_to_aggregate=1, + task_id=0): + """Builds optimization graph. + + * Creates an optimizer, and optionally wraps with SyncReplicasOptimizer + * Computes, clips, and applies gradients + * Maintains moving averages for all trainable variables + * Summarizes variables and gradients + + Args: + loss: scalar loss to minimize. + global_step: integer scalar Variable. + max_grad_norm: float scalar. Grads will be clipped to this value. + lr: float scalar, learning rate. + lr_decay: float scalar, learning rate decay rate. + sync_replicas: bool, whether to use SyncReplicasOptimizer. + replicas_to_aggregate: int, number of replicas to aggregate when using + SyncReplicasOptimizer. + task_id: int, id of the current task; used to ensure proper initialization + of SyncReplicasOptimizer. + + Returns: + train_op + """ + with tf.name_scope('optimization'): + # Compute gradients. + tvars = tf.trainable_variables() + grads = tf.gradients( + loss, + tvars, + aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) + + # Clip non-embedding grads + non_embedding_grads_and_vars = [(g, v) for (g, v) in zip(grads, tvars) + if 'embedding' not in v.op.name] + embedding_grads_and_vars = [(g, v) for (g, v) in zip(grads, tvars) + if 'embedding' in v.op.name] + + ne_grads, ne_vars = zip(*non_embedding_grads_and_vars) + ne_grads, _ = tf.clip_by_global_norm(ne_grads, max_grad_norm) + non_embedding_grads_and_vars = zip(ne_grads, ne_vars) + + grads_and_vars = embedding_grads_and_vars + non_embedding_grads_and_vars + + # Summarize + _summarize_vars_and_grads(grads_and_vars) + + # Decaying learning rate + lr = tf.train.exponential_decay( + lr, global_step, 1, lr_decay, staircase=True) + tf.summary.scalar('learning_rate', lr) + opt = tf.train.AdamOptimizer(lr) + + # Track the moving averages of all trainable variables. 
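+    # The 0.999-decay EMA shadows every trainable variable; evaluation
+    # restores these shadow values via make_restore_average_vars_dict() in
+    # graphs.py instead of the raw training weights.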
+ variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step) + + # Apply gradients + if sync_replicas: + opt = tf.train.SyncReplicasOptimizer( + opt, + replicas_to_aggregate, + variable_averages=variable_averages, + variables_to_average=tvars, + total_num_replicas=replicas_to_aggregate) + apply_gradient_op = opt.apply_gradients( + grads_and_vars, global_step=global_step) + with tf.control_dependencies([apply_gradient_op]): + train_op = tf.no_op(name='train_op') + + # Initialization ops + tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, + opt.get_chief_queue_runner()) + if task_id == 0: # Chief task + local_init_op = opt.chief_init_op + tf.add_to_collection('chief_init_op', opt.get_init_tokens_op()) + else: + local_init_op = opt.local_step_init_op + tf.add_to_collection('local_init_op', local_init_op) + tf.add_to_collection('ready_for_local_init_op', + opt.ready_for_local_init_op) + else: + # Non-sync optimizer + variables_averages_op = variable_averages.apply(tvars) + apply_gradient_op = opt.apply_gradients(grads_and_vars, global_step) + with tf.control_dependencies([apply_gradient_op, variables_averages_op]): + train_op = tf.no_op(name='train_op') + + return train_op + + +def _summarize_vars_and_grads(grads_and_vars): + tf.logging.info('Trainable variables:') + tf.logging.info('-' * 60) + for grad, var in grads_and_vars: + tf.logging.info(var) + + def tag(name, v=var): + return v.op.name + '_' + name + + # Variable summary + mean = tf.reduce_mean(var) + tf.summary.scalar(tag('mean'), mean) + with tf.name_scope(tag('stddev')): + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) + tf.summary.scalar(tag('stddev'), stddev) + tf.summary.scalar(tag('max'), tf.reduce_max(var)) + tf.summary.scalar(tag('min'), tf.reduce_min(var)) + tf.summary.histogram(tag('histogram'), var) + + # Gradient summary + if grad is not None: + if isinstance(grad, tf.IndexedSlices): + grad_values = grad.values + else: + grad_values = grad + + tf.summary.histogram(tag('gradient'), grad_values) + tf.summary.scalar(tag('gradient_norm'), tf.global_norm([grad_values])) + else: + tf.logging.info('Var %s has no gradient', var.op.name) diff --git a/adversarial_text/pretrain.py b/adversarial_text/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..25d6a47669ab4e2a6042ba97a013e9e13ce26bb8 --- /dev/null +++ b/adversarial_text/pretrain.py @@ -0,0 +1,45 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Pretrains a recurrent language model. + +Computational time: + 5 days to train 100000 steps on 1 layer 1024 hidden units LSTM, + 256 embeddings, 400 truncated BP, 64 minibatch and on 4 GPU with + SyncReplicasOptimizer, that is the total minibatch is 256. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +import graphs +import train_utils + +FLAGS = tf.app.flags.FLAGS + + +def main(_): + """Trains Language Model.""" + tf.logging.set_verbosity(tf.logging.INFO) + with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): + model = graphs.get_model() + train_op, loss, global_step = model.language_model_training() + train_utils.run_training(train_op, loss, global_step) + + +if __name__ == '__main__': + tf.app.run() diff --git a/adversarial_text/train_classifier.py b/adversarial_text/train_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..94fba3f6f67330929add25beda5799e9b0cc0d2a --- /dev/null +++ b/adversarial_text/train_classifier.py @@ -0,0 +1,62 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Trains LSTM text classification model. + +Model trains with adversarial or virtual adversarial training. + +Computational time: + 6 hours to train 10000 steps without adversarial or virtual adversarial + training, on 1 layer 1024 hidden units LSTM, 256 embeddings, 400 truncated + BP, 64 minibatch and on single GPU. + + 12 hours to train 10000 steps with adversarial or virtual adversarial + training, with above condition. + +To initialize embedding and LSTM cell weights from a pretrained model, set +FLAGS.pretrained_model_dir to the pretrained model's checkpoint directory. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +import graphs +import train_utils + +flags = tf.app.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string('pretrained_model_dir', None, + 'Directory path to pretrained model to restore from') + + +def main(_): + """Trains LSTM classification model.""" + tf.logging.set_verbosity(tf.logging.INFO) + with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): + model = graphs.get_model() + train_op, loss, global_step = model.classifier_training() + train_utils.run_training( + train_op, + loss, + global_step, + variables_to_restore=model.pretrained_variables, + pretrained_model_dir=FLAGS.pretrained_model_dir) + + +if __name__ == '__main__': + tf.app.run() diff --git a/adversarial_text/train_utils.py b/adversarial_text/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..91104a1352c3878b5a40cfc8563e9a314f61e6ee --- /dev/null +++ b/adversarial_text/train_utils.py @@ -0,0 +1,133 @@ +# Copyright 2017 Google, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities for training adversarial text models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + +flags = tf.app.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string('master', '', 'Master address.') +flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.') +flags.DEFINE_integer('ps_tasks', 0, 'Number of parameter servers.') +flags.DEFINE_string('train_dir', '/tmp/text_train', + 'Directory for logs and checkpoints.') +flags.DEFINE_integer('max_steps', 1000000, 'Number of batches to run.') +flags.DEFINE_boolean('log_device_placement', False, + 'Whether to log device placement.') + + +def run_training(train_op, + loss, + global_step, + variables_to_restore=None, + pretrained_model_dir=None): + """Sets up and runs training loop.""" + tf.gfile.MakeDirs(FLAGS.train_dir) + + # Create pretrain Saver + if pretrained_model_dir: + assert variables_to_restore + tf.logging.info('Will attempt restore from %s: %s', pretrained_model_dir, + variables_to_restore) + saver_for_restore = tf.train.Saver(variables_to_restore) + + # Init ops + if FLAGS.sync_replicas: + local_init_op = tf.get_collection('local_init_op')[0] + ready_for_local_init_op = tf.get_collection('ready_for_local_init_op')[0] + else: + local_init_op = tf.train.Supervisor.USE_DEFAULT + ready_for_local_init_op = tf.train.Supervisor.USE_DEFAULT + + is_chief = FLAGS.task == 0 + sv = tf.train.Supervisor( + logdir=FLAGS.train_dir, + is_chief=is_chief, + save_summaries_secs=5 * 60, + save_model_secs=5 * 60, + local_init_op=local_init_op, + ready_for_local_init_op=ready_for_local_init_op, + global_step=global_step) + + # Delay starting standard services to allow possible pretrained model restore. 
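+  # The chief restores the pretrained variables (and runs the SyncReplicas
+  # init ops) inside the session before the checkpointing and summary services
+  # start, so those services only ever see the restored values.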
+ with sv.managed_session( + master=FLAGS.master, + config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement), + start_standard_services=False) as sess: + # Initialization + if is_chief: + if pretrained_model_dir: + maybe_restore_pretrained_model(sess, saver_for_restore, + pretrained_model_dir) + if FLAGS.sync_replicas: + sess.run(tf.get_collection('chief_init_op')[0]) + sv.start_standard_services(sess) + + sv.start_queue_runners(sess) + + # Training loop + global_step_val = 0 + while not sv.should_stop() and global_step_val < FLAGS.max_steps: + global_step_val = train_step(sess, train_op, loss, global_step) + sv.stop() + + # Final checkpoint + if is_chief: + sv.saver.save(sess, sv.save_path, global_step=global_step) + + +def maybe_restore_pretrained_model(sess, saver_for_restore, model_dir): + """Restores pretrained model if there is no ckpt model.""" + ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir) + checkpoint_exists = ckpt and ckpt.model_checkpoint_path + if checkpoint_exists: + tf.logging.info('Checkpoint exists in FLAGS.train_dir; skipping ' + 'pretraining restore') + return + + pretrain_ckpt = tf.train.get_checkpoint_state(model_dir) + if not (pretrain_ckpt and pretrain_ckpt.model_checkpoint_path): + raise ValueError( + 'Asked to restore model from %s but no checkpoint found.' % model_dir) + saver_for_restore.restore(sess, pretrain_ckpt.model_checkpoint_path) + + +def train_step(sess, train_op, loss, global_step): + """Runs a single training step.""" + start_time = time.time() + _, loss_val, global_step_val = sess.run([train_op, loss, global_step]) + duration = time.time() - start_time + + # Logging + if global_step_val % 10 == 0: + examples_per_sec = FLAGS.batch_size / duration + sec_per_batch = float(duration) + + format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') + tf.logging.info(format_str % (global_step_val, loss_val, examples_per_sec, + sec_per_batch)) + + if np.isnan(loss_val): + raise OverflowError('Loss is nan') + + return global_step_val