Unverified commit 69b01644 authored by Chris Shallue, committed by GitHub

Merge pull request #5546 from cshallue/master

Improvements to AstroNet and add AstroWaveNet
......@@ -40,6 +40,10 @@ Full text available at [*The Astronomical Journal*](http://iopscience.iop.org/ar
* Training and evaluating a new model.
* Using a trained model to generate new predictions.
[astrowavenet/](astrowavenet/)
* A generative model for light curves.
[light_curve_util/](light_curve_util)
* Utilities for operating on light curves. These include:
......@@ -63,11 +67,11 @@ First, ensure that you have installed the following required packages:
* **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
* **Pandas** ([instructions](http://pandas.pydata.org/pandas-docs/stable/install.html))
* **NumPy** ([instructions](https://docs.scipy.org/doc/numpy/user/install.html))
* **SciPy** ([instructions](https://scipy.org/install.html))
* **AstroPy** ([instructions](http://www.astropy.org/))
* **PyDl** ([instructions](https://pypi.python.org/pypi/pydl))
* **Bazel** ([instructions](https://docs.bazel.build/versions/master/install.html))
* **Abseil Python Common Libraries** ([instructions](https://github.com/abseil/abseil-py))
* Optional: only required for unit tests.
### Optional: Run Unit Tests
......@@ -207,7 +211,7 @@ the second deepest transits).
To train a model to identify exoplanets, you will need to provide TensorFlow
with training data in
[TFRecord](https://www.tensorflow.org/guide/datasets) format. The
[TFRecord](https://www.tensorflow.org/programmers_guide/datasets) format. The
TFRecord format consists of a set of sharded files containing serialized
`tf.Example` [protocol buffers](https://developers.google.com/protocol-buffers/).
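For concreteness, here is a minimal sketch (not taken from this repository) of how one serialized `tf.Example` could be written to a TFRecord file with the TensorFlow 1.x API; the feature names and values are purely illustrative:

```python
import tensorflow as tf

# Build a tf.Example with one float feature and one bytes feature
# (feature names here are illustrative, not the repository's schema).
example = tf.train.Example(features=tf.train.Features(feature={
    "global_view": tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.1, 0.2, 0.3])),
    "av_training_set": tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b"PC"])),
}))

# Serialize it into a single-shard TFRecord file.
with tf.python_io.TFRecordWriter("/tmp/train-00000-of-00001") as writer:
  writer.write(example.SerializeToString())
```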
......@@ -343,7 +347,7 @@ bazel-bin/astronet/train \
--model_dir=${MODEL_DIR}
```
Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)
server in a separate process for real-time
monitoring of training progress and evaluation metrics.
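For example, assuming TensorBoard is installed alongside TensorFlow, it can typically be launched with `tensorboard --logdir=${MODEL_DIR}` and viewed in a browser at the address it prints.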
......
......@@ -25,6 +25,7 @@ py_binary(
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
],
)
......@@ -37,6 +38,7 @@ py_binary(
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
],
)
......
......@@ -54,24 +54,6 @@ from astronet.astro_model import astro_model
class AstroCNNModel(astro_model.AstroModel):
"""A model for classifying light curves using a convolutional neural net."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
"""
super(AstroCNNModel, self).__init__(features, labels, hparams, mode)
def _build_cnn_layers(self, inputs, hparams, scope="cnn"):
"""Builds convolutional layers.
......@@ -95,7 +77,7 @@ class AstroCNNModel(astro_model.AstroModel):
for i in range(hparams.cnn_num_blocks):
num_filters = int(hparams.cnn_initial_num_filters *
hparams.cnn_block_filter_factor**i)
with tf.variable_scope("block_%d" % (i + 1)):
with tf.variable_scope("block_{}".format(i + 1)):
for j in range(hparams.cnn_block_size):
net = tf.layers.conv1d(
inputs=net,
......@@ -103,7 +85,7 @@ class AstroCNNModel(astro_model.AstroModel):
kernel_size=int(hparams.cnn_kernel_size),
padding=hparams.convolution_padding,
activation=tf.nn.relu,
name="conv_%d" % (j + 1))
name="conv_{}".format(j + 1))
if hparams.pool_size > 1: # pool_size 0 or 1 denotes no pooling
net = tf.layers.max_pooling1d(
......
......@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -58,24 +58,6 @@ from astronet.astro_model import astro_model
class AstroFCModel(astro_model.AstroModel):
"""A model for classifying light curves using fully connected layers."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
"""
super(AstroFCModel, self).__init__(features, labels, hparams, mode)
def _build_local_fc_layers(self, inputs, hparams, scope):
"""Builds locally fully connected layers.
......@@ -120,8 +102,8 @@ class AstroFCModel(astro_model.AstroModel):
elif hparams.pooling_type == "avg":
net = tf.reduce_mean(net, axis=1, name="avg_pool")
else:
raise ValueError(
"Unrecognized pooling_type: %s" % hparams.pooling_type)
raise ValueError("Unrecognized pooling_type: {}".format(
hparams.pooling_type))
remaining_layers = hparams.num_local_layers - 1
else:
......@@ -133,7 +115,7 @@ class AstroFCModel(astro_model.AstroModel):
inputs=net,
num_outputs=hparams.local_layer_size,
activation_fn=tf.nn.relu,
scope="fully_connected_%d" % (i + 1))
scope="fully_connected_{}".format(i + 1))
if hparams.dropout_rate > 0:
net = tf.layers.dropout(
......
......@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -73,17 +73,19 @@ class AstroModel(object):
"""A TensorFlow model for classifying astrophysical light curves."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
"""Basic setup.
The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
......@@ -93,7 +95,7 @@ class AstroModel(object):
tf.estimator.ModeKeys.PREDICT
]
if mode not in valid_modes:
raise ValueError("Expected mode in %s. Got: %s" % (valid_modes, mode))
raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))
self.hparams = hparams
self.mode = mode
......@@ -201,10 +203,9 @@ class AstroModel(object):
if len(hidden_layers) == 1:
pre_logits_concat = hidden_layers[0][1]
else:
pre_logits_concat = tf.concat(
[layer[1] for layer in hidden_layers],
axis=1,
name="pre_logits_concat")
pre_logits_concat = tf.concat([layer[1] for layer in hidden_layers],
axis=1,
name="pre_logits_concat")
net = pre_logits_concat
with tf.variable_scope("pre_logits_hidden"):
......@@ -213,7 +214,7 @@ class AstroModel(object):
inputs=net,
units=self.hparams.pre_logits_hidden_layer_size,
activation=tf.nn.relu,
name="fully_connected_%s" % (i + 1))
name="fully_connected_{}".format(i + 1))
if self.hparams.pre_logits_dropout_rate > 0:
net = tf.layers.dropout(
......
......@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -45,6 +45,9 @@ def base():
"PC": 1, # Planet Candidate.
"AFP": 0, # Astrophysical False Positive.
"NTP": 0, # Non-Transiting Phenomenon.
"SCR1": 0, # TCE from scrambled light curve with SCR1 order.
"INV": 0, # TCE from inverted light curve.
"INJ1": 1, # Injected Planet.
},
},
# Hyperparameters for building and training the model.
......@@ -60,10 +63,10 @@ def base():
"pre_logits_dropout_rate": 0.0,
# Number of examples per training batch.
"batch_size": 64,
"batch_size": 256,
# Learning rate parameters.
"learning_rate": 1e-5,
"learning_rate": 2e-4,
"learning_rate_decay_steps": 0,
"learning_rate_decay_factor": 0,
"learning_rate_decay_staircase": True,
......
......@@ -88,7 +88,6 @@ import tensorflow as tf
from astronet.data import preprocess
parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
......@@ -100,7 +99,7 @@ parser.add_argument(
required=True,
help="CSV file containing the Q1-Q17 DR24 Kepler TCE table. Must contain "
"columns: rowid, kepid, tce_plnt_num, tce_period, tce_duration, "
"tce_time0bk. Download from: %s" % _DR24_TCE_URL)
"tce_time0bk. Download from: {}".format(_DR24_TCE_URL))
parser.add_argument(
"--kepler_data_dir",
......@@ -219,14 +218,16 @@ def main(argv):
for i in range(FLAGS.num_train_shards):
start = boundaries[i]
end = boundaries[i + 1]
file_shards.append((train_tces[start:end], os.path.join(
FLAGS.output_dir, "train-%.5d-of-%.5d" % (i, FLAGS.num_train_shards))))
filename = os.path.join(
FLAGS.output_dir, "train-{:05d}-of-{:05d}".format(
i, FLAGS.num_train_shards))
file_shards.append((train_tces[start:end], filename))
# Validation and test sets each have a single shard.
file_shards.append((val_tces, os.path.join(FLAGS.output_dir,
"val-00000-of-00001")))
file_shards.append((test_tces, os.path.join(FLAGS.output_dir,
"test-00000-of-00001")))
file_shards.append((val_tces,
os.path.join(FLAGS.output_dir, "val-00000-of-00001")))
file_shards.append((test_tces,
os.path.join(FLAGS.output_dir, "test-00000-of-00001")))
num_file_shards = len(file_shards)
# Launch subprocesses for the file shards.
......
......@@ -34,7 +34,7 @@ def read_light_curve(kepid, kepler_data_dir):
Args:
kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames().
kepler_io.kepler_filenames().
Returns:
all_time: A list of numpy arrays; the time values of the raw light curve.
......@@ -47,8 +47,8 @@ def read_light_curve(kepid, kepler_data_dir):
# Read the Kepler light curve.
file_names = kepler_io.kepler_filenames(kepler_data_dir, kepid)
if not file_names:
raise IOError("Failed to find .fits files in %s for Kepler ID %s" %
(kepler_data_dir, kepid))
raise IOError("Failed to find .fits files in {} for Kepler ID {}".format(
kepler_data_dir, kepid))
return kepler_io.read_kepler_light_curve(file_names)
......@@ -59,7 +59,7 @@ def process_light_curve(all_time, all_flux):
Args:
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
all_time.
Returns:
time: 1D NumPy array; the time values of the light curve.
......@@ -192,7 +192,7 @@ def local_view(time,
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of duration.
num_durations: The number of durations to consider on either side of 0 (the
event is assumed to be centered at 0).
event is assumed to be centered at 0).
Returns:
1D NumPy array of size num_bins containing the median flux values of
......@@ -214,7 +214,7 @@ def generate_example_for_tce(time, flux, tce):
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
tce: Dict-like object containing at least 'tce_period', 'tce_duration', and
'tce_time0bk'. Additional items are included as features in the output.
'tce_time0bk'. Additional items are included as features in the output.
Returns:
A tf.train.Example containing features 'global_view', 'local_view', and all
......
......@@ -26,6 +26,7 @@ import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
......@@ -86,7 +87,9 @@ def main(_):
# Run evaluation. This will log the result to stderr and also write a summary
# file in the model_dir.
estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
eval_steps = None # Evaluate over all examples in the file.
eval_args = {FLAGS.eval_name: (input_fn, eval_steps)}
estimator_runner.evaluate(estimator, eval_args)
if __name__ == "__main__":
......
......@@ -46,7 +46,7 @@ def get_model_class(model_name):
ValueError: If model_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: %s" % model_name)
raise ValueError("Unrecognized model name: {}".format(model_name))
return _MODELS[model_name][0]
......@@ -57,7 +57,7 @@ def get_model_config(model_name, config_name):
Args:
model_name: Name of the model class.
config_name: Name of a configuration-builder function from the model's
configurations module.
configurations module.
Returns:
model_class: The requested model class.
......@@ -67,11 +67,12 @@ def get_model_config(model_name, config_name):
ValueError: If model_name or config_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: %s" % model_name)
raise ValueError("Unrecognized model name: {}".format(model_name))
config_module = _MODELS[model_name][1]
try:
return getattr(config_module, config_name)()
except AttributeError:
raise ValueError("Config name '%s' not found in configuration module: %s" %
(config_name, config_module.__name__))
raise ValueError(
"Config name '{}' not found in configuration module: {}".format(
config_name, config_module.__name__))
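# Usage sketch (illustrative, not part of this file; assumes "AstroCNNModel"
# and a "base" configuration are registered in _MODELS):
#   model_class = get_model_class("AstroCNNModel")
#   config = get_model_config("AstroCNNModel", "base")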
......@@ -69,7 +69,7 @@ def _recursive_pad_to_batch_size(tensor_or_collection, batch_size):
for t in tensor_or_collection
]
raise ValueError("Unknown input type: %s" % tensor_or_collection)
raise ValueError("Unknown input type: {}".format(tensor_or_collection))
def pad_dataset_to_batch_size(dataset, batch_size):
......@@ -119,7 +119,7 @@ def _recursive_set_batch_size(tensor_or_collection, batch_size):
for t in tensor_or_collection:
_recursive_set_batch_size(t, batch_size)
else:
raise ValueError("Unknown input type: %s" % tensor_or_collection)
raise ValueError("Unknown input type: {}".format(tensor_or_collection))
return tensor_or_collection
......@@ -142,19 +142,19 @@ def build_dataset(file_pattern,
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
batch_size: The number of examples per batch.
include_labels: Whether to read labels from the input files.
reverse_time_series_prob: If > 0, the time series features will be randomly
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
shuffle_filenames: Whether to shuffle the order of TFRecord files between
epochs.
epochs.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the dataset
will repeat indefinitely.
will repeat indefinitely.
use_tpu: Whether to build the dataset for TPU.
Raises:
......@@ -170,7 +170,7 @@ def build_dataset(file_pattern,
for p in file_patterns:
matches = tf.gfile.Glob(p)
if not matches:
raise ValueError("Found no input files matching %s" % p)
raise ValueError("Found no input files matching {}".format(p))
filenames.extend(matches)
tf.logging.info("Building input pipeline from %d files matching patterns: %s",
len(filenames), file_patterns)
......@@ -180,8 +180,8 @@ def build_dataset(file_pattern,
label_ids = set(input_config.label_map.values())
if label_ids != set(range(len(label_ids))):
raise ValueError(
"Label IDs must be contiguous integers starting at 0. Got: %s" %
label_ids)
"Label IDs must be contiguous integers starting at 0. Got: {}".format(
label_ids))
# Create a HashTable mapping label strings to integer ids.
table_initializer = tf.contrib.lookup.KeyValueTensorInitializer(
......
......@@ -48,11 +48,6 @@ class DatasetOpsTest(tf.test.TestCase):
self.assertEqual([5], tensor_1d.shape)
self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d.eval())
# Invalid to pad Tensor with batch size 5 to batch size 3.
tensor_1d_pad3 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 3)
with self.assertRaises(tf.errors.InvalidArgumentError):
tensor_1d_pad3.eval()
tensor_1d_pad5 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 5)
self.assertEqual([5], tensor_1d_pad5.shape)
self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d_pad5.eval())
......@@ -66,11 +61,6 @@ class DatasetOpsTest(tf.test.TestCase):
self.assertEqual([3, 3], tensor_2d.shape)
self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], tensor_2d.eval())
tensor_2d_pad2 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 2)
# Invalid to pad Tensor with batch size 2 to batch size 2.
with self.assertRaises(tf.errors.InvalidArgumentError):
tensor_2d_pad2.eval()
tensor_2d_pad3 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 3)
self.assertEqual([3, 3], tensor_2d_pad3.shape)
self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]],
......
......@@ -27,11 +27,10 @@ def prepare_feed_dict(model, features, labels=None, is_training=None):
Args:
model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape
[batch_size, length].
Each is a dictionary of named numpy arrays of shape [batch_size, length].
labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed).
Tensor (if None, no value is fed).
Returns:
feed_dict: A dictionary of input Tensor to numpy array.
......
......@@ -31,9 +31,9 @@ class InputOpsTest(tf.test.TestCase):
Args:
expected_shapes: Dictionary of expected Tensor shapes, as lists,
corresponding to the structure of 'features'.
corresponding to the structure of 'features'.
features: Dictionary of feature placeholders of the format returned by
input_ops.build_feature_placeholders().
input_ops.build_feature_placeholders().
"""
actual_shapes = {}
for feature_type in features:
......
......@@ -30,7 +30,7 @@ def _metric_variable(name, shape, dtype):
collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES])
def _build_metrics(labels, predictions, weights, batch_losses):
def _build_metrics(labels, predictions, weights, batch_losses, output_dim=1):
"""Builds TensorFlow operations to compute model evaluation metrics.
Args:
......@@ -38,14 +38,16 @@ def _build_metrics(labels, predictions, weights, batch_losses):
predictions: Tensor with shape [batch_size, output_dim].
weights: Tensor with shape [batch_size].
batch_losses: Tensor with shape [batch_size].
output_dim: Dimension of the model output.
Returns:
A dictionary {metric_name: (metric_value, update_op)}.
"""
# Compute the predicted labels.
assert len(predictions.shape) == 2
binary_classification = (predictions.shape[1] == 1)
binary_classification = output_dim == 1
if binary_classification:
assert predictions.shape[1] == 1
predictions = tf.squeeze(predictions, axis=[1])
predicted_labels = tf.to_int32(
tf.greater(predictions, 0.5), name="predicted_labels")
......@@ -73,35 +75,31 @@ def _build_metrics(labels, predictions, weights, batch_losses):
metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
batch_losses, weights=weights, name="cross_entropy_loss")
# Possibly create additional metrics for binary classification.
def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
is_equal = tf.to_float(
tf.logical_and(
tf.equal(labels, labels_value),
tf.equal(predicted_labels, predicted_value)))
update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
return count.read_value(), update_op
# Confusion matrix metrics.
num_labels = 2 if binary_classification else output_dim
for gold_label in range(num_labels):
for pred_label in range(num_labels):
metric_name = "confusion_matrix/label_{}_pred_{}".format(
gold_label, pred_label)
metrics[metric_name] = _count_condition(
metric_name, labels_value=gold_label, predicted_value=pred_label)
# Possibly create AUC metric for binary classification.
if binary_classification:
labels = tf.cast(labels, dtype=tf.bool)
predicted_labels = tf.cast(predicted_labels, dtype=tf.bool)
# AUC.
metrics["auc"] = tf.metrics.auc(
labels, predictions, weights=weights, num_thresholds=1000)
def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
is_equal = tf.to_float(
tf.logical_and(
tf.equal(labels, labels_value),
tf.equal(predicted_labels, predicted_value)))
update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
return count.read_value(), update_op
# Confusion matrix metrics.
metrics["confusion_matrix/true_positives"] = _count_condition(
"true_positives", labels_value=True, predicted_value=True)
metrics["confusion_matrix/false_positives"] = _count_condition(
"false_positives", labels_value=False, predicted_value=True)
metrics["confusion_matrix/true_negatives"] = _count_condition(
"true_negatives", labels_value=False, predicted_value=False)
metrics["confusion_matrix/false_negatives"] = _count_condition(
"false_negatives", labels_value=True, predicted_value=False)
return metrics
......@@ -130,7 +128,12 @@ def create_metric_fn(model):
}
def metric_fn(labels, predictions, weights, batch_losses):
return _build_metrics(labels, predictions, weights, batch_losses)
return _build_metrics(
labels,
predictions,
weights,
batch_losses,
output_dim=model.hparams.output_dim)
return metric_fn, metric_fn_inputs
......
......@@ -30,15 +30,23 @@ def _unpack_metric_map(names_to_tuples):
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
class _MockHparams(object):
"""Mock Hparams class to support accessing with dot notation."""
pass
class _MockModel(object):
"""Mock model for testing."""
def __init__(self, labels, predictions, weights, batch_losses):
def __init__(self, labels, predictions, weights, batch_losses, output_dim):
self.labels = tf.constant(labels, dtype=tf.int32)
self.predictions = tf.constant(predictions, dtype=tf.float32)
self.weights = None if weights is None else tf.constant(
weights, dtype=tf.float32)
self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
self.hparams = _MockHparams()
self.hparams.output_dim = output_dim
class MetricsTest(tf.test.TestCase):
......@@ -48,13 +56,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -68,6 +76,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 1,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -76,6 +100,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 4,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 2,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
def testMultiClassificationWithWeights(self):
......@@ -83,13 +123,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -103,6 +143,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 1,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -111,6 +167,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
def testBinaryClassificationWithoutWeights(self):
......@@ -124,7 +196,7 @@ class MetricsTest(tf.test.TestCase):
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -139,10 +211,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 1,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 1,
"confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/label_1_pred_0": 1,
"confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -152,10 +224,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 2,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 2,
"confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/label_1_pred_0": 2,
"confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops))
def testBinaryClassificationWithWeights(self):
......@@ -169,7 +241,7 @@ class MetricsTest(tf.test.TestCase):
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -184,10 +256,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 0,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -197,10 +269,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 0,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops))
......
......@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size):
Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length].
"""
features = {}
features["time_series_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if spec["is_time_series"]
}
features["aux_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if not spec["is_time_series"]
}
features = {"time_series_features": {}, "aux_features": {}}
for name, spec in feature_spec.items():
ftype = "time_series_features" if spec["is_time_series"] else "aux_features"
features[ftype][name] = np.random.random([batch_size, spec["length"]])
return features
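# Usage sketch (illustrative, not part of this file; feature names assumed):
#   spec = {"global_view": {"length": 201, "is_time_series": True},
#           "stellar_density": {"length": 1, "is_time_series": False}}
#   features = fake_features(spec, batch_size=8)
#   # features["time_series_features"]["global_view"].shape == (8, 201)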
......
......@@ -51,7 +51,7 @@ def create_optimizer(hparams, learning_rate, use_tpu=False):
hparams: ConfigDict containing the optimizer configuration.
learning_rate: A Python float or a scalar Tensor.
use_tpu: If True, the returned optimizer is wrapped in a
CrossShardOptimizer.
CrossShardOptimizer.
Returns:
A TensorFlow optimizer.
......@@ -74,7 +74,7 @@ def create_optimizer(hparams, learning_rate, use_tpu=False):
elif optimizer_name == "rmsprop":
optimizer = tf.RMSPropOptimizer(learning_rate)
else:
raise ValueError("Unknown optimizer: %s" % hparams.optimizer)
raise ValueError("Unknown optimizer: {}".format(hparams.optimizer))
if use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
......
......@@ -26,6 +26,7 @@ import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
......@@ -112,11 +113,14 @@ def main(_):
file_pattern=FLAGS.eval_files,
input_config=config.inputs,
mode=tf.estimator.ModeKeys.EVAL)
eval_args = {
"val": (eval_input_fn, None) # eval_name: (input_fn, eval_steps)
}
for _ in estimator_util.continuous_train_and_eval(
for _ in estimator_runner.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
eval_args=eval_args,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each
# training epoch. We don't do anything here.
......
......@@ -32,6 +32,12 @@ py_test(
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
......@@ -47,6 +53,7 @@ py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
......
......@@ -49,19 +49,28 @@ def parse_json(json_string_or_file):
with tf.gfile.Open(json_string_or_file) as f:
json_dict = json.load(f)
except ValueError as json_file_parsing_error:
raise ValueError("Unable to parse the content of the json file %s. "
"Parsing error: %s." % (json_string_or_file,
json_file_parsing_error.message))
raise ValueError("Unable to parse the content of the json file {}. "
"Parsing error: {}.".format(
json_string_or_file,
json_file_parsing_error.message))
except tf.gfile.FileError:
message = ("Unable to parse the input parameter neither as literal "
"JSON nor as the name of a file that exists.\n"
"JSON parsing error: %s\n\n Input parameter:\n%s." %
(literal_json_parsing_error.message, json_string_or_file))
"JSON parsing error: {}\n\n Input parameter:\n{}.".format(
literal_json_parsing_error.message, json_string_or_file))
raise ValueError(message)
return json_dict
def to_json(config):
"""Converts a JSON-serializable configuration object to a JSON string."""
if hasattr(config, "to_json") and callable(config.to_json):
return config.to_json(indent=2)
else:
return json.dumps(config, indent=2)
def log_and_save_config(config, output_dir):
"""Logs and writes a JSON-serializable configuration object.
......@@ -69,10 +78,7 @@ def log_and_save_config(config, output_dir):
config: A JSON-serializable object.
output_dir: Destination directory.
"""
if hasattr(config, "to_json") and callable(config.to_json):
config_json = config.to_json(indent=2)
else:
config_json = json.dumps(config, indent=2)
config_json = to_json(config)
tf.logging.info("config: %s", config_json)
tf.gfile.MakeDirs(output_dir)
......@@ -104,7 +110,7 @@ def unflatten(flat_config):
Args:
flat_config: A dictionary with strings as keys where nested configuration
parameters are represented with period-separated names.
parameters are represented with period-separated names.
Returns:
A dictionary nested according to the keys of the input dictionary.
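# Example (illustrative, not part of this file):
#   unflatten({"hparams.batch_size": 64, "use_tpu": False})
#   ==> {"hparams": {"batch_size": 64}, "use_tpu": False}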
......
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training and evaluation using a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def evaluate(estimator, eval_args):
"""Runs evaluation on the latest model checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
Returns:
global_step: The global step of the checkpoint evaluated.
values: A dict of metric values from the evaluation. May be empty, e.g. if
the training job has not yet saved a checkpoint or the checkpoint is
deleted by the time the TPU worker initializes.
"""
# Default return values if evaluation fails.
global_step = None
values = {}
latest_checkpoint = estimator.latest_checkpoint()
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try:
for eval_name, (input_fn, eval_steps) in eval_args.items():
values[eval_name] = estimator.evaluate(
input_fn, steps=eval_steps, name=eval_name)
if global_step is None:
global_step = values[eval_name].get("global_step")
except (tf.errors.NotFoundError, ValueError):
# Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation.",
latest_checkpoint)
return global_step, values
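# Usage sketch (illustrative, not part of the original file):
#   eval_args = {"val": (val_input_fn, None)}  # None: run until end of input.
#   global_step, values = evaluate(estimator, eval_args)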
def continuous_eval(estimator,
eval_args,
train_steps=None,
timeout_secs=None,
timeout_fn=None):
"""Runs evaluation whenever there's a new checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training.
timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
indefinitely.
timeout_fn: Optional function to call after timeout. The iterator will exit
if and only if the function returns True.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
for _ in tf.contrib.training.checkpoints_iterator(
estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
def continuous_train_and_eval(estimator,
train_input_fn,
eval_args,
local_eval_frequency=None,
train_hooks=None,
train_steps=None):
"""Alternates training and evaluation.
Args:
estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels).
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call.
train_steps: The total number of steps to train the model for.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
while True:
# We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted.
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
# Decide how many steps before the next evaluation.
steps = local_eval_frequency
if train_steps:
remaining_steps = train_steps - global_step
steps = min(steps, remaining_steps) if steps else remaining_steps
tf.logging.info("Starting training at global step %d", global_step)
estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
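# Usage sketch (illustrative, not part of the original file): train for 10000
# steps, evaluating on "val" between training runs and logging the metrics.
#   eval_args = {"val": (eval_input_fn, None)}
#   for global_step, values in continuous_train_and_eval(
#       estimator, train_input_fn, eval_args, train_steps=10000):
#     tf.logging.info("global_step %s: %s", global_step, values)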
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training models with the TensorFlow Estimator API."""
"""Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
......@@ -27,71 +27,104 @@ from astronet.ops import metrics
from astronet.ops import training
def create_input_fn(file_pattern,
input_config,
mode,
shuffle_values_buffer=0,
repeat=1):
"""Creates an input_fn that reads a dataset from sharded TFRecord files.
class _InputFn(object):
"""Class that acts as a callable input function for Estimator train / eval."""
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
def __init__(self,
file_pattern,
input_config,
mode,
shuffle_values_buffer=0,
repeat=1):
"""Initializes the input function.
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the
input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this
size.
repeat: The number of times to repeat the dataset. If None or -1 the
elements will be repeated indefinitely.
Returns:
A callable that builds an input pipeline and returns (features, labels).
"""
include_labels = (
mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL])
reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN)
def input_fn(config, params):
"""Builds an input pipeline that reads a dataset from TFRecord files."""
"""
self._file_pattern = file_pattern
self._input_config = input_config
self._mode = mode
self._shuffle_values_buffer = shuffle_values_buffer
self._repeat = repeat
def __call__(self, config, params):
"""Builds the input pipeline."""
# Infer whether this input_fn was called by Estimator or TPUEstimator using
# the config type.
use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)
mode = self._mode
include_labels = (
mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL])
reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN)
dataset = dataset_ops.build_dataset(
file_pattern=file_pattern,
input_config=input_config,
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=params["batch_size"],
include_labels=include_labels,
reverse_time_series_prob=reverse_time_series_prob,
shuffle_filenames=shuffle_filenames,
shuffle_values_buffer=shuffle_values_buffer,
repeat=repeat,
shuffle_values_buffer=self._shuffle_values_buffer,
repeat=self._repeat,
use_tpu=use_tpu)
return dataset
return input_fn
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
def create_input_fn(file_pattern,
input_config,
mode,
shuffle_values_buffer=0,
repeat=1):
"""Creates an input_fn that reads a dataset from sharded TFRecord files.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the
elements will be repeated indefinitely.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
A callable that builds the input pipeline and returns a tf.data.Dataset
object.
"""
hparams = copy.deepcopy(hparams)
return _InputFn(file_pattern, input_config, mode, shuffle_values_buffer,
repeat)
def model_fn(features, labels, mode, params):
class _ModelFn(object):
"""Class that acts as a callable model function for Estimator train / eval."""
def __init__(self, model_class, hparams, use_tpu=False):
"""Initializes the model function.
Args:
model_class: Model class.
hparams: ConfigDict containing hyperparameters for building and training
the model.
use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
will be returned.
"""
self._model_class = model_class
self._base_hparams = hparams
self._use_tpu = use_tpu
def __call__(self, features, labels, mode, params):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
# For TPUEstimator, params contains the batch size per TPU core.
hparams = copy.deepcopy(self._base_hparams)
if "batch_size" in params:
hparams.batch_size = params["batch_size"]
......@@ -99,14 +132,15 @@ def create_model_fn(model_class, hparams, use_tpu=False):
if "labels" in features:
if labels is not None and labels is not features["labels"]:
raise ValueError(
"Conflicting labels: features['labels'] = %s, labels = %s" %
(features["labels"], labels))
"Conflicting labels: features['labels'] = {}, labels = {}".format(
features["labels"], labels))
labels = features.pop("labels")
model = model_class(features, labels, hparams, mode)
model = self._model_class(features, labels, hparams, mode)
model.build()
# Possibly create train_op.
use_tpu = self._use_tpu
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = training.create_learning_rate(hparams, model.global_step)
......@@ -137,7 +171,21 @@ def create_model_fn(model_class, hparams, use_tpu=False):
return estimator
return model_fn
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
return _ModelFn(model_class, hparams, use_tpu)
def create_estimator(model_class,
......@@ -155,10 +203,10 @@ def create_estimator(model_class,
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
......@@ -202,117 +250,3 @@ def create_estimator(model_class,
params={"batch_size": hparams.batch_size})
return estimator
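# Usage sketch (illustrative; MyModelClass and the path are hypothetical):
#   estimator = create_estimator(MyModelClass, hparams,
#                                model_dir="/tmp/astronet_model")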
def evaluate(estimator, input_fn, eval_steps=None, eval_name="val"):
"""Runs evaluation on the latest model checkpoint.
Args:
estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels).
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Returns:
A dict of metric values from the evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
values = {} # Default return value if evaluation fails.
latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
return values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try:
values = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name)
except tf.errors.NotFoundError:
# Expected under some conditions, e.g. TPU worker does not finish
# initializing until long after the CPU job tells it to start evaluating
# and the checkpoint file is deleted already.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
latest_checkpoint)
return values
def continuous_eval(estimator,
input_fn,
train_steps=None,
eval_steps=None,
eval_name="val"):
"""Runs evaluation whenever there's a new checkpoint.
Args:
estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels).
train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training. If None, this
function will run forever.
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
values = evaluate(estimator, input_fn, eval_steps, eval_name)
yield values
global_step = values.get("global_step", 0)
if train_steps and global_step >= train_steps:
break
def continuous_train_and_eval(estimator,
train_input_fn,
eval_input_fn,
local_eval_frequency=None,
train_hooks=None,
train_steps=None,
eval_steps=None,
eval_name="val"):
"""Alternates training and evaluation.
Args:
estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels).
eval_input_fn: Input function returning a tuple (features, labels).
local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call.
train_steps: The total number of steps to train the model for.
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until eval_input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
while True:
# We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted.
values = evaluate(estimator, eval_input_fn, eval_steps, eval_name)
yield values
global_step = values.get("global_step", 0)
if train_steps and global_step >= train_steps:
break
# Decide how many steps before the next evaluation.
steps = local_eval_frequency
if train_steps:
remaining_steps = train_steps - global_step
steps = min(steps, remaining_steps) if steps else remaining_steps
tf.logging.info("Starting training at global step %d", global_step)
estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
......@@ -28,7 +28,7 @@ def get_feature(ex, name, kind=None, strict=True):
ex: A tf.train.Example.
name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
not specified.
strict: Whether to raise a KeyError if there is no such feature.
Returns:
......@@ -48,7 +48,8 @@ def get_feature(ex, name, kind=None, strict=True):
return np.array([]) # Feature exists, but it's empty.
if kind and kind != inferred_kind:
raise TypeError("Requested %s, but Feature has %s" % (kind, inferred_kind))
raise TypeError("Requested {}, but Feature has {}".format(
kind, inferred_kind))
return np.array(getattr(ex.features.feature[name], inferred_kind).value)
......@@ -79,7 +80,12 @@ def _infer_kind(value):
return "bytes_list"
def set_feature(ex, name, value, kind=None, allow_overwrite=False):
def set_feature(ex,
name,
value,
kind=None,
allow_overwrite=False,
bytes_encoding="latin-1"):
"""Sets a feature value in a tf.train.Example.
Args:
......@@ -87,8 +93,9 @@ def set_feature(ex, name, value, kind=None, allow_overwrite=False):
name: Name of the feature to set.
value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
not specified.
allow_overwrite: Whether to overwrite the existing value of the feature.
bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.
Raises:
ValueError: If `allow_overwrite` is False and the feature already exists, or
......@@ -99,19 +106,20 @@ def set_feature(ex, name, value, kind=None, allow_overwrite=False):
del ex.features.feature[name]
else:
raise ValueError(
"Attempting to set duplicate feature with name: %s" % name)
"Attempting to overwrite feature with name: {}. "
"Set allow_overwrite=True if this is desired.".format(name))
if not kind:
kind = _infer_kind(value)
if kind == "bytes_list":
value = [str(v).encode("latin-1") for v in value]
value = [str(v).encode(bytes_encoding) for v in value]
elif kind == "float_list":
value = [float(v) for v in value]
elif kind == "int64_list":
value = [int(v) for v in value]
else:
raise ValueError("Unrecognized kind: %s" % kind)
raise ValueError("Unrecognized kind: {}".format(kind))
getattr(ex.features.feature[name], kind).value.extend(value)
......@@ -121,9 +129,13 @@ def set_float_feature(ex, name, value, allow_overwrite=False):
set_feature(ex, name, value, "float_list", allow_overwrite)
def set_bytes_feature(ex, name, value, allow_overwrite=False):
def set_bytes_feature(ex,
name,
value,
allow_overwrite=False,
bytes_encoding="latin-1"):
"""Sets the value of a bytes feature in a tf.train.Example."""
set_feature(ex, name, value, "bytes_list", allow_overwrite)
set_feature(ex, name, value, "bytes_list", allow_overwrite, bytes_encoding)
def set_int64_feature(ex, name, value, allow_overwrite=False):
......
"""A TensorFlow model for generative modeling of light curves."""
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_binary(
name = "trainer",
srcs = ["trainer.py"],
srcs_version = "PY2AND3",
deps = [
":astrowavenet_model",
":configurations",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astrowavenet/data:kepler_light_curves",
"//astrowavenet/data:synthetic_transits",
"//astrowavenet/util:estimator_util",
],
)
py_library(
name = "configurations",
srcs = ["configurations.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "astrowavenet_model",
srcs = [
"astrowavenet_model.py",
],
srcs_version = "PY2AND3",
)
py_test(
name = "astrowavenet_model_test",
size = "small",
srcs = [
"astrowavenet_model_test.py",
],
srcs_version = "PY2AND3",
deps = [
":astrowavenet_model",
":configurations",
"//astronet/util:configdict",
],
)
# AstroWaveNet: A generative model for light curves.
Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499
## Code Authors
Alex Tamkin: [@atamkin](https://github.com/atamkin)
Chris Shallue: [@cshallue](https://github.com/cshallue)
## Pull Requests / Issues
Chris Shallue: [@cshallue](https://github.com/cshallue)
## Additional Dependencies
This package requires TensorFlow 1.12 or greater. As of October 2018, this
requires the **TensorFlow nightly build**
([instructions](https://www.tensorflow.org/install/pip)).
In addition to the dependencies listed in the top-level README, this package
requires:
* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/))
## Basic Usage
To train a model on synthetic transits:
```bash
bazel build astrowavenet/...
```
```bash
bazel-bin/astrowavenet/trainer \
--dataset=synthetic_transits \
--model_dir=/tmp/astrowavenet/ \
--config_overrides='{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}' \
--schedule=train_and_eval \
--eval_steps=100 \
--save_checkpoints_steps=1000
```
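To train on real Kepler light curves instead, use the `kepler_light_curves`
dataset and pass TFRecord file patterns for the training and evaluation sets.
The file patterns below are illustrative placeholders; each TFRecord example
is expected to contain `flux` and `mask` features (see
`astrowavenet/data/kepler_light_curves.py`):
```bash
bazel-bin/astrowavenet/trainer \
  --dataset=kepler_light_curves \
  --model_dir=/tmp/astrowavenet_kepler/ \
  --train_files=/path/to/kepler/train-* \
  --eval_files=/path/to/kepler/val-* \
  --schedule=train_and_eval \
  --eval_steps=100 \
  --save_checkpoints_steps=1000
```
To use a quantized categorical output distribution instead of the default
normal distribution, add `--config_name=categorical`.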
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A TensorFlow WaveNet model for generative modeling of light curves.
Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow_probability as tfp
def _shift_right(x):
"""Shifts the input Tensor right by one index along the second dimension.
Pads the front with zeros and discards the last element.
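For example, an input of shape [1, 3, 1] with values [[[1], [2], [3]]]
becomes [[[0], [1], [2]]], so the output at each step contains no
information from the current or later steps.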
Args:
x: Input three-dimensional tf.Tensor.
Returns:
Padded, shifted tensor of same shape as input.
"""
x_padded = tf.pad(x, [[0, 0], [1, 0], [0, 0]])
return x_padded[:, :-1, :]
class AstroWaveNet(object):
"""A TensorFlow model for generative modeling of light curves."""
def __init__(self, features, hparams, mode):
"""Basic setup.
The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "autoregressive_input" and
"conditioning_stack", each of which is a named input Tensor. All
features have dtype float32 and shape [batch_size, length, dim].
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
"""
valid_modes = [
tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
tf.estimator.ModeKeys.PREDICT
]
if mode not in valid_modes:
raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))
self.hparams = hparams
self.mode = mode
self.autoregressive_input = features["autoregressive_input"]
self.conditioning_stack = features["conditioning_stack"]
self.weights = features.get("weights")
self.network_output = None # Sum of skip connections from dilation stack.
self.dist_params = None # Dict of predicted distribution parameters.
self.predicted_distributions = None # Predicted distribution for examples.
self.autoregressive_target = None # Targets for the autoregressive predictions.
self.batch_losses = None # Loss for each predicted distribution in batch.
self.per_example_loss = None # Loss for each example in batch.
self.num_nonzero_weight_examples = None # Number of examples in batch with nonzero weight.
self.total_loss = None # Overall loss for the batch.
self.global_step = None # Global step Tensor.
def causal_conv_layer(self, x, output_size, kernel_width, dilation_rate=1):
"""Applies a dialated causal convolution to the input.
Args:
x: tf.Tensor; Input tensor.
output_size: int; Number of output filters for the convolution.
kernel_width: int; Width of the 1D convolution window.
dilation_rate: int; Dilation rate of the layer.
Returns:
Resulting tf.Tensor after applying the convolution.
"""
causal_conv_op = tf.keras.layers.Conv1D(
output_size,
kernel_width,
padding="causal",
dilation_rate=dilation_rate,
name="causal_conv")
return causal_conv_op(x)
def conv_1x1_layer(self, x, output_size, activation=None):
"""Applies a 1x1 convolution to the input.
Args:
x: tf.Tensor; Input tensor.
output_size: int; Number of output filters for the 1x1 convolution.
activation: Activation function to apply (e.g. 'relu').
Returns:
Resulting tf.Tensor after applying the 1x1 convolution.
"""
conv_1x1_op = tf.keras.layers.Conv1D(
output_size, 1, activation=activation, name="conv1x1")
return conv_1x1_op(x)
def gated_residual_layer(self, x, dilation_rate):
"""Creates a gated, dilated convolutional layer with a residual connnection.
Args:
x: tf.Tensor; Input tensor
dilation_rate: int; Dilation rate of the layer.
Returns:
skip_connection: tf.Tensor; Skip connection to network_output layer.
residual_connection: tf.Tensor; Sum of learned residual and input tensor.
"""
with tf.variable_scope("filter"):
x_filter_conv = self.causal_conv_layer(x, x.shape[-1].value,
self.hparams.dilation_kernel_width,
dilation_rate)
cond_filter_conv = self.conv_1x1_layer(self.conditioning_stack,
x.shape[-1].value)
with tf.variable_scope("gate"):
x_gate_conv = self.causal_conv_layer(x, x.shape[-1].value,
self.hparams.dilation_kernel_width,
dilation_rate)
cond_gate_conv = self.conv_1x1_layer(self.conditioning_stack,
x.shape[-1].value)
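# WaveNet-style gating: a tanh "filter" branch is multiplied elementwise by
# a sigmoid "gate" branch, with the conditioning stack added to each branch.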
gated_activation = (
tf.tanh(x_filter_conv + cond_filter_conv) *
tf.sigmoid(x_gate_conv + cond_gate_conv))
with tf.variable_scope("residual"):
residual = self.conv_1x1_layer(gated_activation, x.shape[-1].value)
with tf.variable_scope("skip"):
skip_connection = self.conv_1x1_layer(gated_activation,
self.hparams.skip_output_dim)
return skip_connection, x + residual
def build_network(self):
"""Builds WaveNet network.
This consists of:
1) An initial causal convolution,
2) The dialation stack, and
3) Summing of skip connections
The network output can then be used to predict various output distributions.
Inputs:
self.autoregressive_input
self.conditioning_stack
Outputs:
self.network_output; tf.Tensor
"""
skip_connections = []
x = _shift_right(self.autoregressive_input)
with tf.variable_scope("preprocess"):
x = self.causal_conv_layer(x, self.hparams.preprocess_output_size,
self.hparams.preprocess_kernel_width)
for i in range(self.hparams.num_residual_blocks):
with tf.variable_scope("block_{}".format(i)):
for dilation_rate in self.hparams.dilation_rates:
with tf.variable_scope("dilation_{}".format(dilation_rate)):
skip_connection, x = self.gated_residual_layer(x, dilation_rate)
skip_connections.append(skip_connection)
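# Sum the skip connections from all dilation layers to form the network
# output.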
self.network_output = tf.add_n(skip_connections)
def dist_params_layer(self, x, outputs_size):
"""Converts x to the correct shape for populating a distribution object.
Args:
x: A Tensor of shape [batch_size, time_series_length, num_features].
outputs_size: The number of parameters needed to specify all the
distributions in the output. E.g. 5*3=15 to specify 5 distributions with
3 parameters each.
Returns:
The parameters of each distribution, a tensor of shape [batch_size,
time_series_length, outputs_size].
"""
with tf.variable_scope("dist_params"):
conv_outputs = self.conv_1x1_layer(x, outputs_size)
return conv_outputs
def build_predictions(self):
"""Predicts output distribution from network outputs.
Runs the network output through:
1) ReLU
2) 1x1 convolution
3) ReLU
4) 1x1 convolution
The result of the last convolution is used as the parameters of the
specified output distribution (currently either Categorical or Normal).
Inputs:
self.network_output
Outputs:
self.dist_params
self.predicted_distributions
Raises:
ValueError: If distribution type is neither 'categorical' nor 'normal'.
"""
with tf.variable_scope("postprocess"):
network_output = tf.keras.activations.relu(self.network_output)
network_output = self.conv_1x1_layer(
network_output,
output_size=network_output.shape[-1].value,
activation="relu")
num_dists = self.autoregressive_input.shape[-1].value
if self.hparams.output_distribution.type == "categorical":
num_classes = self.hparams.output_distribution.num_classes
logits = self.dist_params_layer(network_output, num_dists * num_classes)
logits_shape = tf.concat(
[tf.shape(network_output)[:-1], [num_dists, num_classes]], 0)
logits = tf.reshape(logits, logits_shape)
dist = tfp.distributions.Categorical(logits=logits)
dist_params = {"logits": logits}
elif self.hparams.output_distribution.type == "normal":
loc_scale = self.dist_params_layer(network_output, num_dists * 2)
loc, scale = tf.split(loc_scale, 2, axis=-1)
# Ensure scale is positive.
scale = tf.nn.softplus(scale) + self.hparams.output_distribution.min_scale
dist = tfp.distributions.Normal(loc, scale)
dist_params = {"loc": loc, "scale": scale}
else:
raise ValueError("Unsupported distribution type {}".format(
self.hparams.output_distribution.type))
self.dist_params = dist_params
self.predicted_distributions = dist
def build_losses(self):
"""Builds the training losses.
Inputs:
self.predicted_distributions
Outputs:
self.batch_losses
self.total_loss
"""
autoregressive_target = self.autoregressive_input
# Quantize the target if the output distribution is categorical.
if self.hparams.output_distribution.type == "categorical":
min_val = self.hparams.output_distribution.min_quantization_value
max_val = self.hparams.output_distribution.max_quantization_value
num_classes = self.hparams.output_distribution.num_classes
clipped_target = tf.keras.backend.clip(autoregressive_target, min_val,
max_val)
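# Linearly map the clipped values in [min_val, max_val] to bucket indices
# in [0, num_classes].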
quantized_target = tf.floor(
(clipped_target - min_val) / (max_val - min_val) * num_classes)
# Deal with the corner case where clipped_target equals max_val by mapping
# the label num_classes to num_classes - 1. Essentially, this makes the
# final quantized bucket a closed interval while all the other quantized
# buckets are half-open intervals.
quantized_target = tf.where(
quantized_target >= num_classes,
tf.ones_like(quantized_target) * (num_classes - 1), quantized_target)
autoregressive_target = quantized_target
log_prob = self.predicted_distributions.log_prob(autoregressive_target)
weights = self.weights
if weights is None:
weights = tf.ones_like(log_prob)
weights_dim = len(weights.shape)
per_example_weight = tf.reduce_sum(
weights, axis=list(range(1, weights_dim)))
per_example_indicator = tf.to_float(tf.greater(per_example_weight, 0))
num_examples = tf.reduce_sum(per_example_indicator)
batch_losses = -log_prob * weights
losses_ndims = batch_losses.shape.ndims
per_example_loss_sum = tf.reduce_sum(
batch_losses, axis=list(range(1, losses_ndims)))
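# Normalize each example's loss by its total weight; examples whose
# weights are all zero get a loss of zero.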
per_example_loss = tf.where(per_example_weight > 0,
per_example_loss_sum / per_example_weight,
tf.zeros_like(per_example_weight))
total_loss = tf.reduce_sum(per_example_loss) / num_examples
self.autoregressive_target = autoregressive_target
self.batch_losses = batch_losses
self.per_example_loss = per_example_loss
self.num_nonzero_weight_examples = num_examples
self.total_loss = total_loss
def build(self):
"""Creates all ops for training, evaluation or inference."""
self.global_step = tf.train.get_or_create_global_step()
self.build_network()
self.build_predictions()
if self.mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
self.build_losses()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for model building, training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def base():
"""Returns the base config for model building, training and evaluation."""
return {
# Hyperparameters for building and training the model.
"hparams": {
"batch_size": 64,
"dilation_kernel_width": 2,
"skip_output_dim": 10,
"preprocess_output_size": 3,
"preprocess_kernel_width": 10,
"num_residual_blocks": 4,
"dilation_rates": [1, 2, 4, 8, 16],
"output_distribution": {
"type": "normal",
"min_scale": 0.001
},
# Learning rate parameters.
"learning_rate": 1e-6,
"learning_rate_decay_steps": 0,
"learning_rate_decay_factor": 0,
"learning_rate_decay_staircase": True,
# Optimizer for training the model.
"optimizer": "adam",
# If not None, gradient norms will be clipped to this value.
"clip_gradient_norm": 1,
}
}
def categorical():
"""Returns a config for models with a categorical output distribution.
Input values will be clipped to [min_quantization_value,
max_quantization_value], then linearly quantized into num_classes buckets.
"""
config = base()
config["hparams"]["output_distribution"] = {
"type": "categorical",
"num_classes": 256,
"min_quantization_value": -1,
"max_quantization_value": 1
}
return config
def get_config(config_name):
"""Returns config correspnding to provided name."""
if config_name in ["base", "normal"]:
return base()
elif config_name == "categorical":
return categorical()
else:
raise ValueError("Unrecognized config name: {}".format(config_name))
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "base",
srcs = [
"base.py",
],
deps = [
"//astronet/ops:dataset_ops",
"//astronet/util:configdict",
],
)
py_test(
name = "base_test",
srcs = ["base_test.py"],
data = ["test_data/test-dataset.tfrecord"],
srcs_version = "PY2AND3",
deps = [":base"],
)
py_library(
name = "kepler_light_curves",
srcs = [
"kepler_light_curves.py",
],
deps = [
":base",
"//astronet/util:configdict",
],
)
py_library(
name = "synthetic_transits",
srcs = [
"synthetic_transits.py",
],
deps = [
":base",
":synthetic_transit_maker",
"//astronet/util:configdict",
],
)
py_library(
name = "synthetic_transit_maker",
srcs = [
"synthetic_transit_maker.py",
],
)
py_test(
name = "synthetic_transit_maker_test",
srcs = ["synthetic_transit_maker_test.py"],
srcs_version = "PY2AND3",
deps = [":synthetic_transit_maker"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base dataset builder classes for AstroWaveNet input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
import tensorflow as tf
from astronet.util import configdict
from astronet.ops import dataset_ops
@six.add_metaclass(abc.ABCMeta)
class DatasetBuilder(object):
"""Base class for building a dataset input pipeline for AstroWaveNet."""
def __init__(self, config_overrides=None):
"""Initializes the dataset builder.
Args:
config_overrides: Dict or ConfigDict containing overrides to the default
configuration.
"""
self.config = configdict.ConfigDict(self.default_config())
if config_overrides is not None:
self.config.update(config_overrides)
@staticmethod
def default_config():
"""Returns the default configuration as a ConfigDict or Python dict."""
return {}
@abc.abstractmethod
def build(self, batch_size):
"""Builds the dataset input pipeline.
Args:
batch_size: The number of input examples in each batch.
Returns:
A tf.data.Dataset object.
"""
raise NotImplementedError
@six.add_metaclass(abc.ABCMeta)
class _ShardedDatasetBuilder(DatasetBuilder):
"""Abstract base class for a dataset consisting of sharded files."""
def __init__(self, file_pattern, mode, config_overrides=None, use_tpu=False):
"""Initializes the dataset builder.
Args:
file_pattern: File pattern matching input file shards, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
mode: A tf.estimator.ModeKeys.
config_overrides: Dict or ConfigDict containing overrides to the default
configuration.
use_tpu: Whether to build the dataset for TPU.
"""
super(_ShardedDatasetBuilder, self).__init__(config_overrides)
self.file_pattern = file_pattern
self.mode = mode
self.use_tpu = use_tpu
@staticmethod
def default_config():
config = super(_ShardedDatasetBuilder,
_ShardedDatasetBuilder).default_config()
config.update({
"max_length": 1024,
"shuffle_values_buffer": 1000,
"num_parallel_parser_calls": 4,
"batches_buffer_size": None, # Defaults to max(1, 256 / batch_size).
})
return config
@abc.abstractmethod
def file_reader(self):
"""Returns a function that reads a single sharded file."""
raise NotImplementedError
@abc.abstractmethod
def create_example_parser(self):
"""Returns a function that parses a single tf.Example proto."""
raise NotImplementedError
def _batch_and_pad(self, dataset, batch_size):
"""Combines elements into batches of the same length, padding if needed."""
if self.use_tpu:
padded_length = self.config.max_length
if not padded_length:
raise ValueError("config.max_length is required when using TPU")
# Pad with zeros up to padded_length. Note that this will pad the
# "weights" Tensor with zeros as well, which ensures that padded elements
# do not contribute to the loss.
padded_shapes = {}
for name, shape in dataset.output_shapes.items():
shape.assert_is_compatible_with([None, None]) # Expect a 2D sequence.
dims = shape.as_list()
dims[0] = padded_length
shape = tf.TensorShape(dims)
shape.assert_is_fully_defined()
padded_shapes[name] = shape
else:
# Pad each batch up to the maximum size of each dimension in the batch.
padded_shapes = dataset.output_shapes
return dataset.padded_batch(batch_size, padded_shapes)
def build(self, batch_size):
"""Builds the dataset input pipeline.
Args:
batch_size: The number of input examples in each batch.
Returns:
A tf.data.Dataset.
Raises:
ValueError: If no files match self.file_pattern.
"""
file_patterns = self.file_pattern.split(",")
filenames = []
for p in file_patterns:
matches = tf.gfile.Glob(p)
if not matches:
raise ValueError("Found no input files matching {}".format(p))
filenames.extend(matches)
tf.logging.info(
"Building input pipeline from %d files matching patterns: %s",
len(filenames), file_patterns)
is_training = self.mode == tf.estimator.ModeKeys.TRAIN
# Create a string dataset of filenames, and possibly shuffle.
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
if is_training and len(filenames) > 1:
filename_dataset = filename_dataset.shuffle(len(filenames))
# Read serialized Example protos.
dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
self.file_reader(), cycle_length=8, block_length=8, sloppy=True))
if is_training:
# Shuffle and repeat. Note that shuffle() is before repeat(), so elements
# are shuffled among each epoch of data, and not between epochs of data.
if self.config.shuffle_values_buffer > 0:
dataset = dataset.shuffle(self.config.shuffle_values_buffer)
dataset = dataset.repeat()
# Map the parser over the dataset.
dataset = dataset.map(
self.create_example_parser(),
num_parallel_calls=self.config.num_parallel_parser_calls)
def _prepare_wavenet_inputs(features):
"""Validates features, and clips lengths and adds weights if needed."""
# Validate feature names.
required_features = {"autoregressive_input", "conditioning_stack"}
allowed_features = required_features | {"weights"}
feature_names = features.keys()
if not required_features.issubset(feature_names):
raise ValueError("Features must contain all of: {}. Got: {}".format(
required_features, feature_names))
if not allowed_features.issuperset(feature_names):
raise ValueError("Features can only contain: {}. Got: {}".format(
allowed_features, feature_names))
output = {}
for name, value in features.items():
# Validate shapes. The output dimension is [num_samples, dim].
ndims = len(value.shape)
if ndims == 1:
# Add an extra dimension: [num_samples] -> [num_samples, 1].
value = tf.expand_dims(value, -1)
elif ndims != 2:
raise ValueError(
"Features should be 1D or 2D sequences. Got '{}' = {}".format(
name, value))
if self.config.max_length:
value = value[:self.config.max_length]
output[name] = value
if "weights" not in output:
output["weights"] = tf.ones_like(output["autoregressive_input"])
return output
dataset = dataset.map(_prepare_wavenet_inputs)
# Batch results by up to batch_size.
dataset = self._batch_and_pad(dataset, batch_size)
if is_training:
# The dataset repeats infinitely before batching, so each batch has the
# maximum number of elements.
dataset = dataset_ops.set_batch_size(dataset, batch_size)
elif self.use_tpu and self.mode == tf.estimator.ModeKeys.EVAL:
# Pad to ensure that each batch has the same number of elements.
dataset = dataset_ops.pad_dataset_to_batch_size(dataset, batch_size)
# Prefetch batches.
buffer_size = (
self.config.batches_buffer_size or max(1, int(256 / batch_size)))
dataset = dataset.prefetch(buffer_size)
return dataset
def tfrecord_reader(filename):
"""Returns a tf.data.Dataset that reads a single TFRecord file shard."""
return tf.data.TFRecordDataset(filename, buffer_size=16 * 1000 * 1000)
class TFRecordDataset(_ShardedDatasetBuilder):
"""Builder for a dataset consisting of TFRecord files."""
def file_reader(self):
"""Returns a function that reads a single file shard."""
return tfrecord_reader
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Kepler light curve inputs to the AstroWaveNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astrowavenet.data import base
COND_INPUT_KEY = "mask"
AR_INPUT_KEY = "flux"
class KeplerLightCurves(base.TFRecordDataset):
"""Kepler light curve inputs to the AstroWaveNet model."""
def create_example_parser(self):
def _example_parser(serialized):
"""Parses a single tf.Example proto."""
features = tf.parse_single_example(
serialized,
features={
AR_INPUT_KEY: tf.VarLenFeature(tf.float32),
COND_INPUT_KEY: tf.VarLenFeature(tf.int64),
})
# Extract values from SparseTensor objects.
autoregressive_input = features[AR_INPUT_KEY].values
conditioning_stack = tf.to_float(features[COND_INPUT_KEY].values)
return {
"autoregressive_input": autoregressive_input,
"conditioning_stack": conditioning_stack,
}
return _example_parser
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates synthetic light curves with periodic transit-like dips.
See class docstring below for more information.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class SyntheticTransitMaker(object):
"""Generates synthetic light curves with periodic transit-like dips.
These light curves are generated by thresholding noisy sine waves. Each time
random_light_curve is called, a thresholded sine wave is generated by sampling
parameters uniformly from the ranges specified below.
Attributes:
period_range: A tuple of positive values specifying the range of periods the
sine waves may take.
amplitude_range: A tuple of positive values specifying the range of
amplitudes the sine waves may take.
threshold_ratio_range: A tuple of values in [0, 1) specifying the range of
thresholds as a ratio of the sine wave amplitude.
phase_range: Tuple of values specifying the range of phases the sine wave
may take as a ratio of the sampled period. E.g. a sampled phase of 0.5
would translate the sine wave by half of the period. The most common
reason to override this would be to generate light curves
deterministically (with e.g. (0,0)).
noise_sd_range: A tuple of nonnegative values specifying the range of
standard deviations for the Gaussian noise applied to the sine wave.
"""
def __init__(self,
period_range=(0.5, 4),
amplitude_range=(1, 1),
threshold_ratio_range=(0, 0.99),
phase_range=(0, 1),
noise_sd_range=(0.1, 0.1)):
if threshold_ratio_range[0] < 0 or threshold_ratio_range[1] >= 1:
raise ValueError("Threshold ratio range must be in [0, 1). Got: {}."
.format(threshold_ratio_range))
if amplitude_range[0] <= 0:
raise ValueError(
"Amplitude range must only contain positive numbers. Got: {}.".format(
amplitude_range))
if period_range[0] <= 0:
raise ValueError(
"Period range must only contain positive numbers. Got: {}.".format(
period_range))
if noise_sd_range[0] < 0:
raise ValueError(
"Noise standard deviation range must be nonnegative. Got: {}.".format(
noise_sd_range))
for (start, end), name in [(period_range, "period"),
(amplitude_range, "amplitude"),
(threshold_ratio_range, "threshold ratio"),
(phase_range, "phase range"),
(noise_sd_range, "noise standard deviation")]:
if end < start:
raise ValueError(
"End of {} range may not be less than start. Got: ({}, {})".format(
name, start, end))
self.period_range = period_range
self.amplitude_range = amplitude_range
self.threshold_ratio_range = threshold_ratio_range
self.phase_range = phase_range
self.noise_sd_range = noise_sd_range
def random_light_curve(self, time, mask_prob=0):
"""Samples parameters and generates a light curve.
Args:
time: np.array, x-values to sample from the thresholded sine wave.
mask_prob: Value in [0, 1]; the probability that an individual data point
is set to zero.
Returns:
flux: np.array, values of the masked sampled light curve corresponding to
the provided time array.
mask: np.array of ones and zeros, with zeros indicating masking at the
respective position on the flux array.
"""
period = np.random.uniform(*self.period_range)
phase = np.random.uniform(*self.phase_range) * period
amplitude = np.random.uniform(*self.amplitude_range)
threshold = np.random.uniform(*self.threshold_ratio_range) * amplitude
sin_wave = np.sin(time / period - phase) * amplitude
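# Keep only the portion of the wave below -threshold and shift it up so the
# baseline flux is zero, leaving periodic transit-like dips.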
flux = np.minimum(sin_wave, -threshold) + threshold
noise_sd = np.random.uniform(*self.noise_sd_range)
noise = np.random.normal(scale=noise_sd, size=(len(time),))
flux += noise
# Array of ones and zeros, where zeros indicate masking.
mask = np.random.random(len(time)) > mask_prob
mask = mask.astype(np.float)
return flux * mask, mask
def random_light_curve_generator(self, time, mask_prob=0):
"""Returns a generator function yielding random light curves.
Args:
time: An np.array of x-values to sample from the thresholded sine wave.
mask_prob: Value in [0, 1]; the probability that an individual data point
is set to zero.
Returns:
A generator yielding random light curves.
"""
def generator_fn():
while True:
yield self.random_light_curve(time, mask_prob)
return generator_fn
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for synthetic_transit_maker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from astrowavenet.data import synthetic_transit_maker
class SyntheticTransitMakerTest(absltest.TestCase):
def testBadRangesRaiseExceptions(self):
# Period range cannot contain negative values.
with self.assertRaisesRegexp(ValueError, "Period"):
synthetic_transit_maker.SyntheticTransitMaker(period_range=(-1, 10))
# Amplitude range cannot contain negative values.
with self.assertRaisesRegexp(ValueError, "Amplitude"):
synthetic_transit_maker.SyntheticTransitMaker(amplitude_range=(-10, -1))
# Threshold ratio range must be contained in the half-open interval [0, 1).
with self.assertRaisesRegexp(ValueError, "Threshold ratio"):
synthetic_transit_maker.SyntheticTransitMaker(
threshold_ratio_range=(0, 1))
# Noise standard deviation range must only contain nonnegative values.
with self.assertRaisesRegexp(ValueError, "Noise standard deviation"):
synthetic_transit_maker.SyntheticTransitMaker(noise_sd_range=(-1, 1))
# End of range may not be less than start.
invalid_range = (0.2, 0.1)
range_args = [
"period_range", "threshold_ratio_range", "amplitude_range",
"noise_sd_range", "phase_range"
]
for range_arg in range_args:
with self.assertRaisesRegexp(ValueError, "may not be less"):
synthetic_transit_maker.SyntheticTransitMaker(
**{range_arg: invalid_range})
def testStochasticLightCurveGeneration(self):
transit_maker = synthetic_transit_maker.SyntheticTransitMaker()
time = np.arange(100)
flux, mask = transit_maker.random_light_curve(time, mask_prob=0.4)
self.assertEqual(len(flux), 100)
self.assertEqual(len(mask), 100)
def testDeterministicLightCurveGeneration(self):
gold_flux = np.array([
0., 0., 0., 0., 0., 0., 0., -0.85099258, -2.04776251, -2.65829632,
-2.53014378, -1.69530454, -0.36223792, 0., 0., 0., 0., 0., 0.,
-0.2110405, -1.57757635, -2.47528153, -2.67999913, -2.14061117,
-0.9918028, 0., 0., 0., 0., 0., 0., 0., -1.01475559, -2.15534176,
-2.68282928, -2.46550457, -1.55763357, -0.18591162, 0., 0., 0., 0., 0.,
0., -0.3870683, -1.71426199, -2.53849461, -2.65395535, -2.03181367,
-0.82741829, 0., 0., 0., 0., 0., 0., 0., -1.17380391, -2.2541162,
-2.69666588, -2.39094831, -1.41330116, -0.00784284, 0., 0., 0., 0., 0.,
0., -0.56063229, -1.84372452, -2.59152891, -2.61731875, -1.91465433,
-0.65899089, 0., 0., 0., 0., 0., 0., 0., -1.3275672, -2.34373163,
-2.69975648, -2.30674237, -1.26282489, 0., 0., 0., 0., 0., 0., 0.,
-0.73111006, -1.9654997, -2.63419424, -2.5702207, -1.78955328,
-0.48712456
])
# Use ranges containing one value for determinism.
transit_maker = synthetic_transit_maker.SyntheticTransitMaker(
period_range=(2, 2),
amplitude_range=(3, 3),
threshold_ratio_range=(.1, .1),
phase_range=(0, 0),
noise_sd_range=(0, 0))
time = np.linspace(0, 100, 100)
flux, mask = transit_maker.random_light_curve(time)
np.testing.assert_array_almost_equal(flux, gold_flux)
np.testing.assert_array_almost_equal(mask, np.ones(100))
def testRandomLightCurveGenerator(self):
transit_maker = synthetic_transit_maker.SyntheticTransitMaker()
time = np.linspace(0, 100, 100)
generator = transit_maker.random_light_curve_generator(
time, mask_prob=0.3)()
for _ in range(5):
flux, mask = next(generator)
self.assertEqual(len(flux), 100)
self.assertEqual(len(mask), 100)
if __name__ == "__main__":
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synthetic transit inputs to the AstroWaveNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker
def _prepare_wavenet_inputs(light_curve, mask):
"""Gathers synthetic transits into the format expected by AstroWaveNet."""
return {
"autoregressive_input": tf.expand_dims(light_curve, -1),
"conditioning_stack": tf.expand_dims(mask, -1),
}
class SyntheticTransits(base.DatasetBuilder):
"""Synthetic transit inputs to the AstroWaveNet model."""
@staticmethod
def default_config():
return configdict.ConfigDict({
"period_range": (0.5, 4),
"amplitude_range": (1, 1),
"threshold_ratio_range": (0, 0.99),
"phase_range": (0, 1),
"noise_sd_range": (0.1, 0.1),
"mask_probability": 0.1,
"light_curve_time_range": (0, 100),
"light_curve_num_points": 1000
})
def build(self, batch_size):
transit_maker = synthetic_transit_maker.SyntheticTransitMaker(
period_range=self.config.period_range,
amplitude_range=self.config.amplitude_range,
threshold_ratio_range=self.config.threshold_ratio_range,
phase_range=self.config.phase_range,
noise_sd_range=self.config.noise_sd_range)
t_start, t_end = self.config.light_curve_time_range
time = np.linspace(t_start, t_end, self.config.light_curve_num_points)
dataset = tf.data.Dataset.from_generator(
transit_maker.random_light_curve_generator(
time, mask_prob=self.config.mask_probability),
output_types=(tf.float32, tf.float32),
output_shapes=(tf.TensorShape((self.config.light_curve_num_points,)),
tf.TensorShape((self.config.light_curve_num_points,))))
dataset = dataset.map(_prepare_wavenet_inputs)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(-1)
return dataset
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training and evaluating AstroWaveNet models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os.path
from absl import flags
import tensorflow as tf
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util
FLAGS = flags.FLAGS
flags.DEFINE_enum("dataset", None,
["synthetic_transits", "kepler_light_curves"],
"Dataset for training and/or evaluation.")
flags.DEFINE_string("model_dir", None, "Base output directory.")
flags.DEFINE_string(
"train_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"training dataset.")
flags.DEFINE_string(
"eval_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"evaluation dataset.")
flags.DEFINE_string("config_name", "base",
"Name of the AstroWaveNet configuration.")
flags.DEFINE_string(
"config_overrides", "{}",
"JSON string or JSON file containing overrides to the base configuration.")
flags.DEFINE_enum("schedule", None,
["train", "train_and_eval", "continuous_eval"],
"Schedule for running the model.")
flags.DEFINE_string("eval_name", "val", "Name of the evaluation task.")
flags.DEFINE_integer("train_steps", None, "Total number of steps for training.")
flags.DEFINE_integer("eval_steps", None, "Number of steps for each evaluation.")
flags.DEFINE_integer(
"local_eval_frequency", 1000,
"The number of training steps in between evaluation runs. Only applies "
"when schedule == 'train_and_eval'.")
flags.DEFINE_integer("save_summary_steps", None,
"The frequency at which to save model summaries.")
flags.DEFINE_integer("save_checkpoints_steps", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("save_checkpoints_secs", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("keep_checkpoint_max", 1,
"The maximum number of model checkpoints to keep.")
# ------------------------------------------------------------------------------
# TPU-only flags
# ------------------------------------------------------------------------------
flags.DEFINE_boolean("use_tpu", False, "Whether to execute on TPU.")
flags.DEFINE_string("master", None, "Address of the TensorFlow TPU master.")
flags.DEFINE_integer("tpu_num_shards", 8, "Number of TPU shards.")
flags.DEFINE_integer("tpu_iterations_per_loop", 1000,
"Number of iterations per TPU training loop.")
flags.DEFINE_integer(
"eval_batch_size", None,
"Batch size for TPU evaluation. Defaults to the training batch size.")
def _create_run_config():
"""Creates a TPU RunConfig if FLAGS.use_tpu is True, else a RunConfig."""
session_config = tf.ConfigProto(allow_soft_placement=True)
run_config_kwargs = {
"save_summary_steps": FLAGS.save_summary_steps,
"save_checkpoints_steps": FLAGS.save_checkpoints_steps,
"save_checkpoints_secs": FLAGS.save_checkpoints_secs,
"session_config": session_config,
"keep_checkpoint_max": FLAGS.keep_checkpoint_max
}
if FLAGS.use_tpu:
if not FLAGS.master:
raise ValueError("FLAGS.master must be set for TPUEstimator.")
tpu_config = tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.tpu_iterations_per_loop,
num_shards=FLAGS.tpu_num_shards,
per_host_input_for_training=(FLAGS.tpu_num_shards <= 8))
run_config = tf.contrib.tpu.RunConfig(
tpu_config=tpu_config, master=FLAGS.master, **run_config_kwargs)
else:
if FLAGS.master:
raise ValueError("FLAGS.master should only be set for TPUEstimator.")
run_config = tf.estimator.RunConfig(**run_config_kwargs)
return run_config
def _get_file_pattern(mode):
"""Gets the value of the file pattern flag for the specified mode."""
flag_name = ("train_files"
if mode == tf.estimator.ModeKeys.TRAIN else "eval_files")
file_pattern = FLAGS[flag_name].value
if file_pattern is None:
raise ValueError("--{} is required for mode '{}'".format(flag_name, mode))
return file_pattern
def _create_dataset_builder(mode, config_overrides=None):
"""Creates a dataset builder for the input pipeline."""
if FLAGS.dataset == "synthetic_transits":
return synthetic_transits.SyntheticTransits(config_overrides)
file_pattern = _get_file_pattern(mode)
if FLAGS.dataset == "kepler_light_curves":
builder_class = kepler_light_curves.KeplerLightCurves
else:
raise ValueError("Unsupported dataset: {}".format(FLAGS.dataset))
return builder_class(
file_pattern,
mode,
config_overrides=config_overrides,
use_tpu=FLAGS.use_tpu)
def _create_input_fn(mode, config_overrides=None):
"""Creates an Estimator input_fn."""
builder = _create_dataset_builder(mode, config_overrides)
tf.logging.info("Dataset config for mode '%s': %s", mode,
config_util.to_json(builder.config))
return estimator_util.create_input_fn(builder)
def _create_eval_args(config_overrides=None):
"""Builds eval_args for estimator_runner.evaluate()."""
if FLAGS.dataset == "synthetic_transits" and not FLAGS.eval_steps:
raise ValueError("Dataset '{}' requires --eval_steps for evaluation".format(
FLAGS.dataset))
input_fn = _create_input_fn(tf.estimator.ModeKeys.EVAL, config_overrides)
return {FLAGS.eval_name: (input_fn, FLAGS.eval_steps)}
def main(argv):
del argv # Unused.
config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
config_overrides = json.loads(FLAGS.config_overrides)
for key in config_overrides:
if key not in ["dataset", "hparams"]:
raise ValueError("Unrecognized config override: {}".format(key))
config.hparams.update(config_overrides.get("hparams", {}))
# Log configs.
configs_json = [
("config_overrides", config_util.to_json(config_overrides)),
("config", config_util.to_json(config)),
]
for config_name, config_json in configs_json:
tf.logging.info("%s: %s", config_name, config_json)
# Create the estimator.
run_config = _create_run_config()
estimator = estimator_util.create_estimator(
astrowavenet_model.AstroWaveNet, config.hparams, run_config,
FLAGS.model_dir, FLAGS.eval_batch_size)
if FLAGS.schedule in ["train", "train_and_eval"]:
# Save configs.
tf.gfile.MakeDirs(FLAGS.model_dir)
for config_name, config_json in configs_json:
filename = os.path.join(FLAGS.model_dir, "{}.json".format(config_name))
with tf.gfile.Open(filename, "w") as f:
f.write(config_json)
train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
config_overrides.get("dataset"))
train_hooks = []
if FLAGS.schedule == "train":
estimator.train(
train_input_fn, hooks=train_hooks, max_steps=FLAGS.train_steps)
else:
assert FLAGS.schedule == "train_and_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_args=eval_args,
local_eval_frequency=FLAGS.local_eval_frequency,
train_hooks=train_hooks,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after every
# FLAGS.local_eval_frequency training steps. It also saves and logs them,
# so we don't do anything here.
pass
else:
assert FLAGS.schedule == "continuous_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_eval(
estimator=estimator, eval_args=eval_args,
train_steps=FLAGS.train_steps):
# continuous_eval() yields evaluation metrics after each checkpoint. It
# also saves and logs them, so we don't do anything here.
pass
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
flags.mark_flags_as_required(["dataset", "model_dir", "schedule"])
def _validate_schedule(flag_values):
"""Validates the --schedule flag and the flags it interacts with."""
schedule = flag_values["schedule"]
save_checkpoints_steps = flag_values["save_checkpoints_steps"]
save_checkpoints_secs = flag_values["save_checkpoints_secs"]
if schedule in ["train", "train_and_eval"]:
if not (save_checkpoints_steps or save_checkpoints_secs):
raise flags.ValidationError(
"--schedule='%s' requires --save_checkpoints_steps or "
"--save_checkpoints_secs." % schedule)
return True
flags.register_multi_flags_validator(
["schedule", "save_checkpoints_steps", "save_checkpoints_secs"],
_validate_schedule)
tf.app.run()
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
srcs_version = "PY2AND3",
deps = ["//astronet/ops:training"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from astronet.ops import training
class _InputFn(object):
"""Class that acts as a callable input function for Estimator train / eval."""
def __init__(self, dataset_builder):
"""Initializes the input function.
Args:
dataset_builder: Instance of DatasetBuilder.
"""
self._builder = dataset_builder
def __call__(self, params):
"""Builds the input pipeline."""
return self._builder.build(batch_size=params["batch_size"])
def create_input_fn(dataset_builder):
"""Creates an input_fn that that builds an input pipeline.
Args:
dataset_builder: Instance of DatasetBuilder.
Returns:
A callable that builds an input pipeline and returns a tf.data.Dataset
object.
"""
return _InputFn(dataset_builder)
class _ModelFn(object):
"""Class that acts as a callable model function for Estimator train / eval."""
def __init__(self, model_class, hparams, use_tpu=False):
"""Initializes the model function.
Args:
model_class: Model class.
hparams: A HParams object containing hyperparameters for building and
training the model.
use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
will be returned.
"""
self._model_class = model_class
self._base_hparams = hparams
self._use_tpu = use_tpu
def __call__(self, features, mode, params):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
hparams = copy.deepcopy(self._base_hparams)
if "batch_size" in params:
hparams.batch_size = params["batch_size"]
model = self._model_class(features, hparams, mode)
model.build()
# Possibly create train_op.
use_tpu = self._use_tpu
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = training.create_learning_rate(hparams, model.global_step)
optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
train_op = training.create_train_op(model, optimizer)
if use_tpu:
estimator = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
else:
estimator = tf.estimator.EstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
return estimator
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
return _ModelFn(model_class, hparams, use_tpu)
def create_estimator(model_class,
hparams,
run_config=None,
model_dir=None,
eval_batch_size=None):
"""Wraps model_class as an Estimator or TPUEstimator.
If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.
Args:
model_class: AstroWaveNet or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.
Raises:
ValueError:
If model_dir is not passed explicitly or in run_config.model_dir, or if
eval_batch_size is specified and run_config is not a
tf.contrib.tpu.RunConfig.
"""
if run_config is None:
run_config = tf.estimator.RunConfig()
else:
run_config = copy.deepcopy(run_config)
if not model_dir and not run_config.model_dir:
raise ValueError(
"model_dir must be passed explicitly or specified in run_config")
use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
model_fn = create_model_fn(model_class, hparams, use_tpu)
if use_tpu:
eval_batch_size = eval_batch_size or hparams.batch_size
estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
train_batch_size=hparams.batch_size,
eval_batch_size=eval_batch_size)
else:
if eval_batch_size is not None:
raise ValueError("eval_batch_size can only be specified for TPU.")
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
params={"batch_size": hparams.batch_size})
return estimator
......@@ -6,6 +6,7 @@ py_library(
name = "kepler_io",
srcs = ["kepler_io.py"],
srcs_version = "PY2AND3",
deps = [":util"],
)
py_test(
......
......@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
if __name__ == "__main__":
absltest.main()
......@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
np.testing.assert_almost_equal(folded_flux, expected_flux)
if __name__ == "__main__":
absltest.main()
......@@ -24,7 +24,8 @@ def ValueErrorOnFalse(ok, *output_args):
"""Raises ValueError if not ok, otherwise returns the output arguments."""
n_outputs = len(output_args)
if n_outputs < 2:
raise ValueError("Expected 2 or more output_args. Got: %d" % n_outputs)
raise ValueError(
"Expected 2 or more output_args. Got: {}".format(n_outputs))
if not ok:
error = output_args[-1]
......
......@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
if __name__ == "__main__":
absltest.main()
......@@ -23,10 +23,9 @@ import os.path
from astropy.io import fits
import numpy as np
from light_curve_util import util
from tensorflow import gfile
LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes.
# Quarter index to filename prefix for long cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
LONG_CADENCE_QUARTER_PREFIXES = {
......@@ -73,6 +72,14 @@ SHORT_CADENCE_QUARTER_PREFIXES = {
17: ["2013121191144", "2013131215648"]
}
# Quarter order for different scrambling procedures.
# Page 9: https://ntrs.nasa.gov/archive/nasa/casi.ntrs.nasa.gov/20170009549.pdf.
SIMULATED_DATA_SCRAMBLE_ORDERS = {
"SCR1": [0, 13, 14, 15, 16, 9, 10, 11, 12, 5, 6, 7, 8, 1, 2, 3, 4, 17],
"SCR2": [0, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12, 5, 6, 7, 8, 17],
"SCR3": [0, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 17],
}
def kepler_filenames(base_dir,
kep_id,
......@@ -98,21 +105,21 @@ def kepler_filenames(base_dir,
Args:
base_dir: Base directory containing Kepler data.
kep_id: Id of the Kepler target star. May be an int or a possibly zero-
padded string.
long_cadence: Whether to read a long cadence (~29.4 min / measurement) light
curve as opposed to a short cadence (~1 min / measurement) light curve.
quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
mission to return.
injected_group: Optional string indicating injected light curves. One of
"inj1", "inj2", "inj3".
"inj1", "inj2", "inj3".
check_existence: If True, only return filenames corresponding to files that
exist (not all stars have data for all quarters).
Returns:
A list of filenames.
"""
# Pad the Kepler id with zeros to length 9.
kep_id = "%.9d" % int(kep_id)
kep_id = "{:09d}".format(int(kep_id))
quarter_prefixes, cadence_suffix = ((LONG_CADENCE_QUARTER_PREFIXES, "llc")
if long_cadence else
......@@ -128,12 +135,11 @@ def kepler_filenames(base_dir,
for quarter in quarters:
for quarter_prefix in quarter_prefixes[quarter]:
if injected_group:
base_name = "kplr%s-%s_INJECTED-%s_%s.fits" % (kep_id, quarter_prefix,
injected_group,
cadence_suffix)
base_name = "kplr{}-{}_INJECTED-{}_{}.fits".format(
kep_id, quarter_prefix, injected_group, cadence_suffix)
else:
base_name = "kplr%s-%s_%s.fits" % (kep_id, quarter_prefix,
cadence_suffix)
base_name = "kplr{}-{}_{}.fits".format(kep_id, quarter_prefix,
cadence_suffix)
filename = os.path.join(base_dir, base_name)
# Not all stars have data for all quarters.
if not check_existence or gfile.Exists(filename):
......@@ -142,40 +148,86 @@ def kepler_filenames(base_dir,
return filenames
def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
"""Scrambles a light curve according to a given scrambling procedure.
Args:
all_time: List holding arrays of time values, each containing a quarter of
time data.
all_flux: List holding arrays of flux values, each containing a quarter of
flux data.
all_quarters: List of integers specifying which quarters are present in
the light curve (max is 18: Q0...Q17).
scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2',
'SCR3'}.
Returns:
scr_flux: Scrambled flux values; the same list as the input flux in another
order.
scr_time: Time values, re-partitioned to match sizes of the scr_flux lists.
"""
order = SIMULATED_DATA_SCRAMBLE_ORDERS[scramble_type]
scr_flux = []
for quarter in order:
# Ignore missing quarters in the scramble order.
if quarter in all_quarters:
scr_flux.append(all_flux[all_quarters.index(quarter)])
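# Time values stay in chronological order, but are re-split so that each
# piece matches the length of the corresponding scrambled flux piece.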
scr_time = util.reshard_arrays(all_time, scr_flux)
return scr_time, scr_flux
def read_kepler_light_curve(filenames,
light_curve_extension="LIGHTCURVE",
invert=False):
scramble_type=None,
interpolate_missing_time=False):
"""Reads time and flux measurements for a Kepler target star.
Args:
filenames: A list of .fits files containing time and flux measurements.
light_curve_extension: Name of the HDU 1 extension containing light curves.
scramble_type: What scrambling procedure to use: 'SCR1', 'SCR2', or 'SCR3'
(pg 9: https://exoplanetarchive.ipac.caltech.edu/docs/KSCI-19114-002.pdf).
interpolate_missing_time: Whether to interpolate missing (NaN) time values.
This should only affect the output if scramble_type is specified (NaN time
values typically come with NaN flux values, which are removed anyway, but
scrambling decouples NaN time values from NaN flux values).
Returns:
all_time: A list of numpy arrays; the time values of the light curve.
all_flux: A list of numpy arrays; the flux values of the light curve.
"""
all_time = []
all_flux = []
all_quarters = []
for filename in filenames:
with fits.open(gfile.Open(filename, "rb")) as hdu_list:
quarter = hdu_list["PRIMARY"].header["QUARTER"]
light_curve = hdu_list[light_curve_extension].data
time = light_curve.TIME
flux = light_curve.PDCSAP_FLUX
if not time.size:
continue # No data.
# Possibly interpolate missing time values.
if interpolate_missing_time:
time = util.interpolate_missing_time(time, light_curve.CADENCENO)
all_time.append(time)
all_flux.append(flux)
all_quarters.append(quarter)
if scramble_type:
all_time, all_flux = scramble_light_curve(all_time, all_flux, all_quarters,
scramble_type)
# Remove timestamps with NaN time or flux values.
for i, (time, flux) in enumerate(zip(all_time, all_flux)):
flux_and_time_finite = np.logical_and(np.isfinite(flux), np.isfinite(time))
all_time[i] = time[flux_and_time_finite]
all_flux[i] = flux[flux_and_time_finite]
return all_time, all_flux
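A typical end-to-end call chains the two functions above (the path is hypothetical; the keyword arguments mirror the unit tests below):

```
from light_curve_util import kepler_io

filenames = kepler_io.kepler_filenames(
    "/path/to/kepler", 11442793, check_existence=True)
all_time, all_flux = kepler_io.read_kepler_light_curve(
    filenames, scramble_type="SCR1", interpolate_missing_time=True)
# all_time[i] and all_flux[i] are aligned numpy arrays with every NaN
# timestamp removed.
```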
......@@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function
import os.path
from absl import flags
from absl.testing import absltest
import numpy as np
from light_curve_util import kepler_io
......@@ -34,6 +36,26 @@ class KeplerIoTest(absltest.TestCase):
def setUp(self):
self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR)
def testScrambleLightCurve(self):
all_flux = [[11, 12], [21], [np.nan, np.nan, 33], [41, 42]]
all_time = [[101, 102], [201], [301, 302, 303], [401, 402]]
all_quarters = [3, 4, 7, 14]
scramble_type = "SCR1" # New quarters order will be [14,7,3,4].
scr_time, scr_flux = kepler_io.scramble_light_curve(
all_time, all_flux, all_quarters, scramble_type)
# NaNs are not removed in this function.
gold_flux = [[41, 42], [np.nan, np.nan, 33], [11, 12], [21]]
gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]]
self.assertEqual(len(gold_flux), len(scr_flux))
self.assertEqual(len(gold_time), len(scr_time))
for i in range(len(gold_flux)):
np.testing.assert_array_equal(gold_flux[i], scr_flux[i])
np.testing.assert_array_equal(gold_time[i], scr_time[i])
def testKeplerFilenames(self):
# All quarters.
filenames = kepler_io.kepler_filenames(
......@@ -100,15 +122,17 @@ class KeplerIoTest(absltest.TestCase):
filenames = kepler_io.kepler_filenames(
self.data_dir, 11442793, check_existence=True)
expected_filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
self.assertItemsEqual(expected_filenames, filenames)
def testReadKeplerLightCurve(self):
filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(filenames)
self.assertLen(all_time, 3)
......@@ -120,6 +144,55 @@ class KeplerIoTest(absltest.TestCase):
self.assertLen(all_time[2], 4486)
self.assertLen(all_flux[2], 4486)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
def testReadKeplerLightCurveScrambled(self):
filenames = [
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(
filenames, scramble_type="SCR1")
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
    # Arrays are shorter than in the unscrambled test above: scrambling
    # decouples NaN time values from NaN flux values, so more cadences are
    # removed.
self.assertLen(all_time[0], 4344)
self.assertLen(all_flux[0], 4344)
self.assertLen(all_time[1], 4041)
self.assertLen(all_flux[1], 4041)
self.assertLen(all_time[2], 1008)
self.assertLen(all_flux[2], 1008)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
def testReadKeplerLightCurveScrambledInterpolateMissingTime(self):
filenames = [
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(
filenames, scramble_type="SCR1", interpolate_missing_time=True)
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
self.assertLen(all_time[0], 4486)
self.assertLen(all_flux[0], 4486)
self.assertLen(all_time[1], 4134)
self.assertLen(all_flux[1], 4134)
self.assertLen(all_time[2], 1008)
self.assertLen(all_flux[2], 1008)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
if __name__ == "__main__":
FLAGS.test_srcdir = ""
......
......@@ -32,16 +32,16 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and all elements cannot be the same value.
y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at
least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x).
Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly
......@@ -51,35 +51,35 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
ValueError: If an argument has an inappropriate value.
"""
if num_bins < 2:
raise ValueError("num_bins must be at least 2. Got: %d" % num_bins)
raise ValueError("num_bins must be at least 2. Got: {}".format(num_bins))
# Validate the lengths of x and y.
x_len = len(x)
if x_len < 2:
raise ValueError("len(x) must be at least 2. Got: %s" % x_len)
raise ValueError("len(x) must be at least 2. Got: {}".format(x_len))
if x_len != len(y):
raise ValueError("len(x) (got: %d) must equal len(y) (got: %d)" % (x_len,
len(y)))
raise ValueError("len(x) (got: {}) must equal len(y) (got: {})".format(
x_len, len(y)))
# Validate x_min and x_max.
x_min = x_min if x_min is not None else x[0]
x_max = x_max if x_max is not None else x[-1]
if x_min >= x_max:
raise ValueError("x_min (got: %d) must be less than x_max (got: %d)" %
(x_min, x_max))
raise ValueError("x_min (got: {}) must be less than x_max (got: {})".format(
x_min, x_max))
if x_min > x[-1]:
raise ValueError(
"x_min (got: %d) must be less than or equal to the largest value of x "
"(got: %d)" % (x_min, x[-1]))
"x_min (got: {}) must be less than or equal to the largest value of x "
"(got: {})".format(x_min, x[-1]))
# Validate bin_width.
bin_width = bin_width if bin_width is not None else (x_max - x_min) / num_bins
if bin_width <= 0:
raise ValueError("bin_width must be positive. Got: %d" % bin_width)
raise ValueError("bin_width must be positive. Got: {}".format(bin_width))
if bin_width >= x_max - x_min:
raise ValueError(
"bin_width (got: %d) must be less than x_max - x_min (got: %d)" %
(bin_width, x_max - x_min))
"bin_width (got: {}) must be less than x_max - x_min (got: {})".format(
bin_width, x_max - x_min))
bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1)
......
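For intuition about the bin geometry, a toy usage sketch (the data values are invented for illustration; the module path is assumed to be light_curve_util.median_filter):

```
import numpy as np
from light_curve_util import median_filter

x = np.linspace(0.0, 10.0, 100)
y = np.sin(x)
binned = median_filter.median_filter(x, y, num_bins=5)
# With the defaults: bin_width = (10 - 0) / 5 = 2.0 and
# bin_spacing = (10 - 0 - 2.0) / (5 - 1) = 2.0, so the bins start at
# x = 0, 2, 4, 6, 8 and tile [x_min, x_max) without overlap.
```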
......@@ -124,5 +124,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_array_equal([7, 1, 5, 2, 3], result)
if __name__ == "__main__":
absltest.main()
......@@ -62,7 +62,7 @@ class Event(object):
other_event: An Event.
period_rtol: Relative tolerance in matching the periods.
t0_durations: Tolerance in matching the t0 values, in units of the other
Event's duration.
Returns:
True if this Event is the same as other_event, within the given tolerance.
......
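The visible docstring only sketches the matching semantics. One plausible implementation of that comparison, written here as a standalone sketch (the attribute names period, t0, and duration are assumptions inferred from the docstring, not necessarily the repository's exact logic):

```
import numpy as np

def events_match(event, other_event, period_rtol=1e-3, t0_durations=1.0):
  # Sketch only: attribute names are assumed, not taken from the source.
  # Periods must agree to within a relative tolerance.
  if not np.isclose(event.period, other_event.period, rtol=period_rtol):
    return False
  # Epochs must agree modulo the period, to within t0_durations times the
  # other event's transit duration.
  t0_diff = (event.t0 - other_event.t0) % other_event.period
  t0_diff = min(t0_diff, other_event.period - t0_diff)
  return t0_diff < t0_durations * other_event.duration
```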
......@@ -17,7 +17,7 @@ def robust_mean(y, cut):
Args:
y: 1D numpy array. Assumed to be normally distributed with outliers.
cut: Points more than this number of standard deviations from the median are
ignored.
Returns:
mean: A robust estimate of the mean of y.
......
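A minimal sketch of the behavior the docstring describes, assuming a MAD-based estimate of the standard deviation (the repository's actual implementation may differ):

```
import numpy as np

def robust_mean_sketch(y, cut):
  # Sketch only, not the repository's implementation.
  median = np.median(y)
  # 1.4826 * MAD approximates the standard deviation for Gaussian data.
  sigma = 1.4826 * np.median(np.abs(y - median))
  keep = np.abs(y - median) <= cut * sigma
  return np.mean(y[keep])
```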