提交 ab6e3fa2 编写于 作者: A Alex Tamkin 提交者: Christopher Shallue

Add AstroWaveNet model, a generative model of astronomical light curves....

Add AstroWaveNet model, a generative model of astronomical light curves. AstroWaveNet's hidden states can be used in a semi-supervised fashion for downstream tasks like finding exoplanets.

PiperOrigin-RevId: 214073725
上级 650f0a3d
"""A TensorFlow model for generative modeling of light curves."""
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configurations",
srcs = ["configurations.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "astrowavenet",
srcs = [
"astrowavenet.py",
],
srcs_version = "PY2AND3",
)
py_test(
name = "astrowavenet_test",
size = "small",
srcs = [
"astrowavenet_test.py",
],
srcs_version = "PY2AND3",
deps = [
":astrowavenet",
":configurations",
"//astronet/util:configdict",
],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A TensorFlow WaveNet model for generative modeling of light curves.
Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def _shift_right(x):
"""Shifts the input Tensor right by one index along the second dimension.
Pads the front with zeros and discards the last element.
Args:
x: Input three-dimensional tf.Tensor.
Returns:
Padded, shifted tensor of same shape as input.
"""
x_padded = tf.pad(x, [[0, 0], [1, 0], [0, 0]])
return x_padded[:, :-1, :]
class AstroWaveNet(object):
"""A TensorFlow model for generative modeling of light curves."""
def __init__(self, features, hparams, mode):
"""Basic setup.
The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "autoregressive_input" and
"conditioning_stack", each of which is a named input Tensor. All
features have dtype float32 and shape [batch_size, length, dim].
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
"""
valid_modes = [
tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
tf.estimator.ModeKeys.PREDICT
]
if mode not in valid_modes:
raise ValueError('Expected mode in {}. Got: {}'.format(valid_modes, mode))
self.hparams = hparams
self.mode = mode
self.autoregressive_input = features['autoregressive_input']
self.conditioning_stack = features['conditioning_stack']
self.weights = features.get('weights')
self.network_output = None # Sum of skip connections from dilation stack.
self.predicted_distributions = None # Predicted distribution for examples.
self.batch_losses = None # Loss for each predicted distribution in batch.
self.per_example_loss = None # Loss for each example in batch.
self.total_loss = None # Overall loss for the batch.
self.global_step = None # Global step Tensor.
def causal_conv_layer(self, x, output_size, kernel_width, dilation_rate=1):
"""Applies a dialated causal convolution to the input.
Args:
x: tf.Tensor; Input tensor.
output_size: int; Number of output filters for the convolution.
kernel_width: int; Width of the 1D convolution window.
dilation_rate: int; Dilation rate of the layer.
Returns:
Resulting tf.Tensor after applying the convolution.
"""
causal_conv_op = tf.keras.layers.Conv1D(
output_size,
kernel_width,
padding='causal',
dilation_rate=dilation_rate,
name='causal_conv')
return causal_conv_op(x)
def conv_1x1_layer(self, x, output_size, activation=None):
"""Applies a 1x1 convolution to the input.
Args:
x: tf.Tensor; Input tensor.
output_size: int; Number of output filters for the 1x1 convolution.
activation: Activation function to apply (e.g. 'relu').
Returns:
Resulting tf.Tensor after applying the 1x1 convolution.
"""
conv_1x1_op = tf.keras.layers.Conv1D(
output_size, 1, activation=activation, name='conv1x1')
return conv_1x1_op(x)
def gated_residual_layer(self, x, dilation_rate):
"""Creates a gated, dilated convolutional layer with a residual connnection.
Args:
x: tf.Tensor; Input tensor
dilation_rate: int; Dilation rate of the layer.
Returns:
skip_connection: tf.Tensor; Skip connection to network_output layer.
residual_connection: tf.Tensor; Sum of learned residual and input tensor.
"""
with tf.variable_scope('filter'):
x_filter_conv = self.causal_conv_layer(x, int(
x.shape[-1]), self.hparams.dilation_kernel_width, dilation_rate)
cond_filter_conv = self.conv_1x1_layer(self.conditioning_stack,
int(x.shape[-1]))
with tf.variable_scope('gate'):
x_gate_conv = self.causal_conv_layer(x, int(
x.shape[-1]), self.hparams.dilation_kernel_width, dilation_rate)
cond_gate_conv = self.conv_1x1_layer(self.conditioning_stack,
int(x.shape[-1]))
gated_activation = (
tf.tanh(x_filter_conv + cond_filter_conv) *
tf.sigmoid(x_gate_conv + cond_gate_conv))
with tf.variable_scope('residual'):
residual = self.conv_1x1_layer(gated_activation, int(x.shape[-1]))
with tf.variable_scope('skip'):
skip_connection = self.conv_1x1_layer(gated_activation,
self.hparams.skip_output_dim)
return skip_connection, x + residual
def build_network(self):
"""Builds WaveNet network.
This consists of:
1) An initial causal convolution,
2) The dialation stack, and
3) Summing of skip connections
The network output can then be used to predict various output distributions.
Inputs:
self.autoregressive_input
self.conditioning_stack
Outputs:
self.network_output; tf.Tensor
"""
skip_connections = []
x = _shift_right(self.autoregressive_input)
with tf.variable_scope('preprocess'):
x = self.causal_conv_layer(x, self.hparams.preprocess_output_size,
self.hparams.preprocess_kernel_width)
for i in range(self.hparams.num_residual_blocks):
with tf.variable_scope('block_{}'.format(i)):
for dilation_rate in self.hparams.dilation_rates:
with tf.variable_scope('dilation_{}'.format(dilation_rate)):
skip_connection, x = self.gated_residual_layer(x, dilation_rate)
skip_connections.append(skip_connection)
self.network_output = tf.add_n(skip_connections)
def dist_params_layer(self, x, outputs_size):
"""Converts x to the correct shape for populating a distribution object.
Args:
x: A Tensor of shape [batch_size, time_series_length, num_features].
outputs_size: The number of parameters needed to specify all the
distributions in the output. E.g. 5*3=15 to specify 5 distributions with
3 parameters each.
Returns:
The parameters of each distribution, a tensor of shape [batch_size,
time_series_length, outputs_size].
"""
with tf.variable_scope('dist_params'):
conv_outputs = self.conv_1x1_layer(x, outputs_size)
return conv_outputs
def build_predictions(self):
"""Predicts output distribution from network outputs.
Runs the model through:
1) ReLU
2) 1x1 convolution
3) ReLU
4) 1x1 convolution
The result of the last convolution is used as the parameters of the
specified output distribution (currently either Categorical or Normal).
Inputs:
self.network_outputs
Outputs:
self.predicted_distributions
Raises:
ValueError: If distribution type is neither 'categorical' nor 'normal'.
"""
with tf.variable_scope('postprocess'):
network_output = tf.keras.activations.relu(self.network_output)
network_output = self.conv_1x1_layer(
network_output,
output_size=int(network_output.shape[-1]),
activation='relu')
num_dists = int(self.autoregressive_input.shape[-1])
if self.hparams.output_distribution.type == 'categorical':
num_classes = self.hparams.output_distribution.num_classes
dist_params = self.dist_params_layer(network_output,
num_dists * num_classes)
dist_shape = tf.concat(
[tf.shape(network_output)[:-1], [num_dists, num_classes]], 0)
dist_params = tf.reshape(dist_params, dist_shape)
dist = tf.distributions.Categorical(logits=dist_params)
elif self.hparams.output_distribution.type == 'normal':
dist_params = self.dist_params_layer(network_output, num_dists * 2)
loc, scale = tf.split(dist_params, 2, axis=-1)
# Ensure scale is positive.
scale = tf.nn.softplus(scale) + self.hparams.output_distribution.min_scale
dist = tf.distributions.Normal(loc, scale)
else:
raise ValueError('Unsupported distribution type {}'.format(
self.hparams.output_distribution.type))
self.predicted_distributions = dist
def build_losses(self):
"""Builds the training losses.
Inputs:
self.predicted_distributions
Outputs:
self.batch_losses
self.total_loss
"""
autoregressive_target = self.autoregressive_input
# Quantize the target if the output distribution is categorical.
if self.hparams.output_distribution.type == 'categorical':
min_val = self.hparams.output_distribution.min_quantization_value
max_val = self.hparams.output_distribution.max_quantization_value
num_classes = self.hparams.output_distribution.num_classes
clipped_target = tf.keras.backend.clip(autoregressive_target, min_val,
max_val)
quantized_target = tf.floor(
(clipped_target - min_val) / (max_val - min_val) * num_classes)
# Deal with the corner case where clipped_target equals max_val by mapping
# the label num_classes to num_classes - 1. Essentially, this makes the
# final quantized bucket a closed interval while all the other quantized
# buckets are half-open intervals.
quantized_target = tf.where(
quantized_target == num_classes,
tf.ones_like(quantized_target) * (num_classes - 1), quantized_target)
autoregressive_target = quantized_target
log_prob = self.predicted_distributions.log_prob(autoregressive_target)
weights = self.weights
if weights is None:
weights = tf.ones_like(log_prob)
weights_dim = len(weights.shape)
per_example_weight = tf.reduce_sum(weights, axis=range(1, weights_dim))
per_example_indicator = tf.to_float(tf.greater(per_example_weight, 0))
num_examples = tf.reduce_sum(
per_example_indicator, name='num_nonzero_weight_examples')
batch_losses = -log_prob * weights
losses_dim = len(batch_losses.shape)
per_example_loss_sum = tf.reduce_sum(
batch_losses, axis=range(1, losses_dim))
per_example_loss = tf.where(per_example_weight > 0,
per_example_loss_sum / per_example_weight,
tf.zeros_like(per_example_weight))
total_loss = tf.reduce_sum(per_example_loss) / num_examples
self.batch_losses = batch_losses
self.per_example_loss = per_example_loss
self.total_loss = total_loss
def build(self):
"""Creates all ops for training, evaluation or inference."""
self.global_step = tf.train.get_or_create_global_step()
self.build_network()
self.build_predictions()
if self.mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
self.build_losses()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for astrowavenet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from astrowavenet import astrowavenet
class AstrowavenetTest(tf.test.TestCase):
def assertShapeEquals(self, shape, tensor_or_array):
"""Asserts that a Tensor or Numpy array has the expected shape.
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
elif isinstance(tensor_or_array, (tf.Tensor, tf.Variable)):
self.assertAllEqual(shape, tensor_or_array.shape.as_list())
else:
raise TypeError('tensor_or_array must be a Tensor or Numpy ndarray')
def test_build_model(self):
batch_size = 11
time_series_length = 9
input_num_features = 8
context_num_features = 7
input_placeholder = tf.placeholder(
dtype=tf.float32,
shape=[None, time_series_length, input_num_features],
name='input')
context_placeholder = tf.placeholder(
dtype=tf.float32,
shape=[None, time_series_length, context_num_features],
name='context')
features = {
'autoregressive_input': input_placeholder,
'conditioning_stack': context_placeholder
}
mode = tf.estimator.ModeKeys.TRAIN
hparams = configdict.ConfigDict({
'dilation_kernel_width': 2,
'skip_output_dim': 6,
'preprocess_output_size': 3,
'preprocess_kernel_width': 5,
'num_residual_blocks': 2,
'dilation_rates': [1, 2, 4],
'output_distribution': {
'type': 'normal',
'min_scale': 0.001,
'num_classes': 256,
}
})
model = astrowavenet.AstroWaveNet(features, hparams, mode)
model.build()
variables = {v.op.name: v for v in tf.trainable_variables()}
# Verify variable shapes in two residual blocks.
var = variables['preprocess/causal_conv/kernel']
self.assertShapeEquals((5, 8, 3), var)
var = variables['preprocess/causal_conv/bias']
self.assertShapeEquals((3,), var)
var = variables['block_0/dilation_1/filter/causal_conv/kernel']
self.assertShapeEquals((2, 3, 3), var)
var = variables['block_0/dilation_1/filter/causal_conv/bias']
self.assertShapeEquals((3,), var)
var = variables['block_0/dilation_1/filter/conv1x1/kernel']
self.assertShapeEquals((1, 7, 3), var)
var = variables['block_0/dilation_1/filter/conv1x1/bias']
self.assertShapeEquals((3,), var)
var = variables['block_0/dilation_1/gate/causal_conv/kernel']
self.assertShapeEquals((2, 3, 3), var)
var = variables['block_0/dilation_1/gate/causal_conv/bias']
self.assertShapeEquals((3,), var)
var = variables['block_0/dilation_1/gate/conv1x1/kernel']
self.assertShapeEquals((1, 7, 3), var)
var = variables['block_0/dilation_1/gate/conv1x1/bias']
self.assertShapeEquals((3,), var)
var = variables['block_0/dilation_1/residual/conv1x1/kernel']
self.assertShapeEquals((1, 3, 3), var)
var = variables['block_0/dilation_1/residual/conv1x1/bias']
self.assertShapeEquals((3,), var)
var = variables['block_0/dilation_1/skip/conv1x1/kernel']
self.assertShapeEquals((1, 3, 6), var)
var = variables['block_0/dilation_1/skip/conv1x1/bias']
self.assertShapeEquals((6,), var)
var = variables['block_1/dilation_4/filter/causal_conv/kernel']
self.assertShapeEquals((2, 3, 3), var)
var = variables['block_1/dilation_4/filter/causal_conv/bias']
self.assertShapeEquals((3,), var)
var = variables['block_1/dilation_4/filter/conv1x1/kernel']
self.assertShapeEquals((1, 7, 3), var)
var = variables['block_1/dilation_4/filter/conv1x1/bias']
self.assertShapeEquals((3,), var)
var = variables['block_1/dilation_4/gate/causal_conv/kernel']
self.assertShapeEquals((2, 3, 3), var)
var = variables['block_1/dilation_4/gate/causal_conv/bias']
self.assertShapeEquals((3,), var)
var = variables['block_1/dilation_4/gate/conv1x1/kernel']
self.assertShapeEquals((1, 7, 3), var)
var = variables['block_1/dilation_4/gate/conv1x1/bias']
self.assertShapeEquals((3,), var)
var = variables['block_1/dilation_4/residual/conv1x1/kernel']
self.assertShapeEquals((1, 3, 3), var)
var = variables['block_1/dilation_4/residual/conv1x1/bias']
self.assertShapeEquals((3,), var)
var = variables['block_1/dilation_4/skip/conv1x1/kernel']
self.assertShapeEquals((1, 3, 6), var)
var = variables['block_1/dilation_4/skip/conv1x1/bias']
self.assertShapeEquals((6,), var)
var = variables['postprocess/conv1x1/kernel']
self.assertShapeEquals((1, 6, 6), var)
var = variables['postprocess/conv1x1/bias']
self.assertShapeEquals((6,), var)
var = variables['dist_params/conv1x1/kernel']
self.assertShapeEquals((1, 6, 16), var)
var = variables['dist_params/conv1x1/bias']
self.assertShapeEquals((16,), var)
# Verify total number of trainable parameters.
num_preprocess_params = (
hparams.preprocess_kernel_width * input_num_features *
hparams.preprocess_output_size + hparams.preprocess_output_size)
num_gated_params = (
hparams.dilation_kernel_width * hparams.preprocess_output_size *
hparams.preprocess_output_size + hparams.preprocess_output_size +
1 * context_num_features * hparams.preprocess_output_size +
hparams.preprocess_output_size) * 2
num_residual_params = (
1 * hparams.preprocess_output_size * hparams.preprocess_output_size +
hparams.preprocess_output_size)
num_skip_params = (
1 * hparams.preprocess_output_size * hparams.skip_output_dim +
hparams.skip_output_dim)
num_block_params = (
num_gated_params + num_residual_params + num_skip_params) * len(
hparams.dilation_rates) * hparams.num_residual_blocks
num_postprocess_params = (
1 * hparams.skip_output_dim * hparams.skip_output_dim +
hparams.skip_output_dim)
num_dist_params = (1 * hparams.skip_output_dim * 2 * input_num_features +
2 * input_num_features)
total_params = (
num_preprocess_params + num_block_params + num_postprocess_params +
num_dist_params)
total_retrieved_params = 0
for v in tf.trainable_variables():
total_retrieved_params += np.prod(v.shape)
self.assertEqual(total_params, total_retrieved_params)
# Verify model runs and outputs losses of correct shape.
scaffold = tf.train.Scaffold()
scaffold.finalize()
with self.cached_session() as sess:
sess.run([scaffold.init_op, scaffold.local_init_op])
step = sess.run(model.global_step)
self.assertEqual(0, step)
feed_dict = {
input_placeholder:
np.random.random((batch_size, time_series_length,
input_num_features)),
context_placeholder:
np.random.random((batch_size, time_series_length,
context_num_features))
}
batch_losses, per_example_loss, total_loss = sess.run(
[model.batch_losses, model.per_example_loss, model.total_loss],
feed_dict=feed_dict)
self.assertShapeEquals(
(batch_size, time_series_length, input_num_features), batch_losses)
self.assertShapeEquals((batch_size,), per_example_loss)
self.assertShapeEquals((), total_loss)
def test_build_model_categorical(self):
batch_size = 11
time_series_length = 9
input_num_features = 8
context_num_features = 7
input_placeholder = tf.placeholder(
dtype=tf.float32,
shape=[None, time_series_length, input_num_features],
name='input')
context_placeholder = tf.placeholder(
dtype=tf.float32,
shape=[None, time_series_length, context_num_features],
name='context')
features = {
'autoregressive_input': input_placeholder,
'conditioning_stack': context_placeholder
}
mode = tf.estimator.ModeKeys.TRAIN
hparams = configdict.ConfigDict({
'dilation_kernel_width': 2,
'skip_output_dim': 6,
'preprocess_output_size': 3,
'preprocess_kernel_width': 5,
'num_residual_blocks': 2,
'dilation_rates': [1, 2, 4],
'output_distribution': {
'type': 'categorical',
'num_classes': 256,
'min_quantization_value': -1,
'max_quantization_value': 1
}
})
model = astrowavenet.AstroWaveNet(features, hparams, mode)
model.build()
variables = {v.op.name: v for v in tf.trainable_variables()}
var = variables['dist_params/conv1x1/kernel']
self.assertShapeEquals(
(1, hparams.skip_output_dim,
hparams.output_distribution.num_classes * input_num_features), var)
var = variables['dist_params/conv1x1/bias']
self.assertShapeEquals(
(hparams.output_distribution.num_classes * input_num_features,), var)
# Verify model runs and outputs losses of correct shape.
scaffold = tf.train.Scaffold()
scaffold.finalize()
with self.cached_session() as sess:
sess.run([scaffold.init_op, scaffold.local_init_op])
step = sess.run(model.global_step)
self.assertEqual(0, step)
feed_dict = {
input_placeholder:
np.random.random((batch_size, time_series_length,
input_num_features)),
context_placeholder:
np.random.random((batch_size, time_series_length,
context_num_features))
}
batch_losses, per_example_loss, total_loss = sess.run(
[model.batch_losses, model.per_example_loss, model.total_loss],
feed_dict=feed_dict)
self.assertShapeEquals(
(batch_size, time_series_length, input_num_features), batch_losses)
self.assertShapeEquals((batch_size,), per_example_loss)
self.assertShapeEquals((), total_loss)
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for model building, training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def base():
"""Returns the base config for model building, training and evaluation."""
return {
# Hyperparameters for building and training the model.
"hparams": {
"batch_size": 64,
"dilation_kernel_width": 2,
"skip_output_dim": 10,
"preprocess_output_size": 3,
"preprocess_kernel_width": 10,
"num_residual_blocks": 4,
"dilation_rates": [1, 2, 4, 8, 16],
"output_distribution": {
"type": "normal",
"min_scale": 0.001
},
# Learning rate parameters.
"learning_rate": 1e-6,
"learning_rate_decay_steps": 0,
"learning_rate_decay_factor": 0,
"learning_rate_decay_staircase": True,
# Optimizer for training the model.
"optimizer": "adam",
# If not None, gradient norms will be clipped to this value.
"clip_gradient_norm": 1,
}
}
def categorical():
"""Returns a config for models with a categorical output distribution.
Input values will be clipped to {min,max}_value_for_quantization, then
linearly split into num_classes.
"""
config = base()
config["hparams"]["output_distribution"] = {
"type": "categorical",
"num_classes": 256,
"min_quantization_value": -1,
"max_quantization_value": 1
}
return config
def get_config(config_name):
"""Returns config correspnding to provided name."""
if config_name in ["base", "normal"]:
return base()
elif config_name == "categorical":
return categorical()
else:
raise ValueError("Unrecognized config name: {}".format(config_name))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册