Commit 88165b4d authored by Ian Langmore, committed by TensorFlower Gardener

LinearOperator:

* contrib/linalg/ started
* linear_operator.py added, containing the base class LinearOperator.
* linear_operator_diag.py added, containing the first derived class
  LinearOperatorDiag.
* A base class for tests is also added.
Change: 139866369
Parent 833c706e
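For orientation, here is a minimal sketch of how the new module is meant to be used, based on the files in this commit (TensorFlow API of this era; assumes a build that includes this change):

```python
import tensorflow as tf

linalg = tf.contrib.linalg

# A 2 x 2 diagonal operator; the dense matrix is never materialized.
operator = linalg.LinearOperatorDiag([2., -1.])

x = tf.ones(shape=[2, 3])
y = operator.apply(x)         # Each row of x scaled by the diagonal.
det = operator.determinant()  # Product of the diagonal: -2.

with tf.Session() as sess:
  print(sess.run([y, det]))
```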
@@ -103,6 +103,7 @@ filegroup(
"//tensorflow/contrib/layers/kernels:all_files",
"//tensorflow/contrib/learn:all_files",
"//tensorflow/contrib/learn/python/learn/datasets:all_files",
"//tensorflow/contrib/linalg:all_files",
"//tensorflow/contrib/linear_optimizer:all_files",
"//tensorflow/contrib/lookup:all_files",
"//tensorflow/contrib/losses:all_files",
......
@@ -27,6 +27,7 @@ py_library(
"//tensorflow/contrib/labeled_tensor",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
......
@@ -32,6 +32,7 @@ from tensorflow.contrib import integrate
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import linalg
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import losses
......
# Description:
# Contains classes that provide access to common methods on a (batch) matrix,
# without the need to materialize the matrix.
# APIs here are meant to evolve over time.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
cuda_py_tests(
name = "linear_operator_test",
size = "small",
srcs = ["python/kernel_tests/linear_operator_test.py"],
additional_deps = [
":linalg_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "linear_operator_diag_test",
size = "small",
srcs = ["python/kernel_tests/linear_operator_diag_test.py"],
additional_deps = [
":linalg_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
shard_count = 5,
)
py_library(
name = "linalg_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
srcs_version = "PY2AND3",
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear algebra libraries for TensorFlow.
## `LinearOperator`
Subclasses of `LinearOperator` provide access to common methods on a
(batch) matrix, without the need to materialize the matrix. This allows:
* Matrix-free computations
* Operators that take advantage of special structure, while providing a
consistent API to users.
### Base class
@@LinearOperator
### Individual operators
@@LinearOperatorDiag
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.linalg.python.ops.linear_operator import *
from tensorflow.contrib.linalg.python.ops.linear_operator_diag import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ops module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.linalg.python.ops import linear_operator_test_util
linalg = tf.contrib.linalg
tf.set_random_seed(23)
class LinearOperatorDiagTest(
linear_operator_test_util.LinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@property
def _dtypes_to_test(self):
return [tf.float32, tf.float64]
@property
def _shapes_to_test(self):
# non-batch operators (n, n) and batch operators.
return [(0, 0), (1, 1), (1, 3, 3), (3, 2, 2), (2, 1, 3, 3)]
def _make_rhs(self, operator):
# This operator is square, so rhs and x will have same shape.
return self._make_x(operator)
def _make_x(self, operator):
# R is the number of systems to solve, either 1 or 2.
r = self._get_num_systems(operator)
# If operator.shape = [B1,...,Bb, N, N] this returns a random matrix of
# shape [B1,...,Bb, N, R], R = 1 or 2.
if operator.shape.is_fully_defined():
batch_shape = operator.batch_shape.as_list()
n = operator.domain_dimension.value
rhs_shape = batch_shape + [n, r]
else:
batch_shape = operator.batch_shape_dynamic()
n = operator.domain_dimension_dynamic()
rhs_shape = tf.concat(0, (batch_shape, [n, r]))
return tf.random_normal(shape=rhs_shape, dtype=operator.dtype)
def _get_num_systems(self, operator):
"""Get some number, either 1 or 2, depending on operator."""
if operator.tensor_rank is None or operator.tensor_rank % 2:
return 1
else:
return 2
def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
shape = list(shape)
diag_shape = shape[:-1]
diag = tf.random_normal(diag_shape, dtype=dtype)
diag_ph = tf.placeholder(dtype=dtype)
if use_placeholder:
# Evaluate the diag here because (i) you cannot feed a tensor, and (ii)
# diag is random and we want the same value used for both mat and
# feed_dict.
diag = diag.eval()
mat = tf.matrix_diag(diag)
operator = linalg.LinearOperatorDiag(diag_ph)
feed_dict = {diag_ph: diag}
else:
mat = tf.matrix_diag(diag)
operator = linalg.LinearOperatorDiag(diag)
feed_dict = None
return operator, mat, feed_dict
def test_assert_positive_definite(self):
# Matrix with one positive and one negative eigenvalue: not positive definite.
with self.test_session():
diag = [1.0, -1.0]
operator = linalg.LinearOperatorDiag(diag)
with self.assertRaisesOpError("was not positive definite"):
operator.assert_positive_definite().run()
def test_assert_non_singular(self):
# Singular matrix with one positive eigenvalue and one zero eigenvalue.
with self.test_session():
diag = [1.0, 0.0]
operator = linalg.LinearOperatorDiag(diag)
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_broadcast_apply_and_solve(self):
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
with self.test_session() as sess:
x = tf.random_normal(shape=(2, 2, 3, 4))
# This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve
# and apply with 'x' as the argument.
diag = tf.random_uniform(shape=(2, 1, 3))
operator = linalg.LinearOperatorDiag(diag)
self.assertAllEqual((2, 1, 3, 3), operator.shape)
# Create a batch matrix with the broadcast shape of operator.
diag_broadcast = tf.concat(1, (diag, diag))
mat = tf.matrix_diag(diag_broadcast)
self.assertAllEqual((2, 2, 3, 3), mat.get_shape()) # being pedantic.
operator_apply = operator.apply(x)
mat_apply = tf.batch_matmul(mat, x)
self.assertAllEqual(operator_apply.get_shape(), mat_apply.get_shape())
self.assertAllClose(*sess.run([operator_apply, mat_apply]))
operator_solve = operator.solve(x)
mat_solve = tf.matrix_solve(mat, x)
self.assertAllEqual(operator_solve.get_shape(), mat_solve.get_shape())
self.assertAllClose(*sess.run([operator_solve, mat_solve]))
if __name__ == "__main__":
tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
linalg = tf.contrib.linalg
class LinearOperatorShape(linalg.LinearOperator):
"""LinearOperator that implements the methods ._shape and _shape_dynamic."""
def __init__(self,
shape,
is_non_singular=None,
is_self_adjoint=None,
is_positive_definite=None):
self._stored_shape = shape
super(LinearOperatorShape, self).__init__(
dtype=tf.float32,
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,)
def _shape(self):
return tf.TensorShape(self._stored_shape)
def _shape_dynamic(self):
return tf.constant(self._stored_shape, dtype=tf.int32)
class LinearOperatorTest(tf.test.TestCase):
def test_all_shape_properties_defined_by_the_one_property_shape(self):
shape = (1, 2, 3, 4)
operator = LinearOperatorShape(shape)
self.assertAllEqual(shape, operator.shape)
self.assertAllEqual(4, operator.tensor_rank)
self.assertAllEqual((1, 2), operator.batch_shape)
self.assertAllEqual(4, operator.domain_dimension)
self.assertAllEqual(3, operator.range_dimension)
def test_all_shape_methods_defined_by_the_one_method_shape(self):
with self.test_session():
shape = (1, 2, 3, 4)
operator = LinearOperatorShape(shape)
self.assertAllEqual(shape, operator.shape_dynamic().eval())
self.assertAllEqual(4, operator.tensor_rank_dynamic().eval())
self.assertAllEqual((1, 2), operator.batch_shape_dynamic().eval())
self.assertAllEqual(4, operator.domain_dimension_dynamic().eval())
self.assertAllEqual(3, operator.range_dimension_dynamic().eval())
def test_is_x_properties(self):
operator = LinearOperatorShape(
shape=(2, 2),
is_non_singular=False,
is_self_adjoint=True,
is_positive_definite=False)
self.assertFalse(operator.is_non_singular)
self.assertTrue(operator.is_self_adjoint)
self.assertFalse(operator.is_positive_definite)
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for linear operators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
__all__ = ["LinearOperator"]
class LinearOperator(object):
"""Base class defining a [batch of] linear operator[s].
Subclasses of `LinearOperator` provide access to common methods on a
(batch) matrix, without the need to materialize the matrix. This allows:
* Matrix-free computations
* Operators that take advantage of special structure, while providing a
consistent API to users.
### Subclassing
To enable a public method, subclasses should implement the leading-underscore
version of the method. The argument signature should be identical except for
the omission of `name="..."`. For example, to enable
`apply(x, adjoint=False, name="apply")` a subclass should implement
`_apply(x, adjoint=False)`.
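Schematically, the division of labor looks like this (mirroring the
`apply` / `_apply` pair defined later in this class):
```python
def apply(self, x, adjoint=False, name="apply"):
  # Base class: opens the name scope and converts `x` to a Tensor, then ...
  with self._name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    return self._apply(x, adjoint=adjoint)

def _apply(self, x, adjoint=False):
  # ... the subclass supplies the math, with the same signature minus `name`.
  raise NotImplementedError("_apply is not implemented.")
```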
### Performance contract
Subclasses should implement a method only if it can be done with a reasonable
performance increase over generic dense operations, either in time, parallel
scalability, or memory usage. For example, if the determinant can only be
computed using `tf.matrix_determinant(self.to_dense())`, then determinants
should not be implemented.
Class docstrings should contain an explanation of computational complexity.
Since this is a high-performance library, attention should be paid to detail,
and explanations can include constants as well as Big-O notation.
### Shape compatibility
`LinearOperator` subclasses should operate on a [batch] matrix with
compatible shape. Class docstrings should define what is meant by compatible
shape. Some subclasses may not support batching.
An example is:
`x` is a batch matrix with compatible shape for `apply` if
```
operator.shape = [B1,...,Bb] + [M, N], b >= 0,
x.shape = [B1,...,Bb] + [N, R]
```
`rhs` is a batch matrix with compatible shape for `solve` if
```
operator.shape = [B1,...,Bb] + [M, N], b >= 0,
rhs.shape = [B1,...,Bb] + [M, R]
```
### Example docstring for subclasses.
This operator acts like a (batch) matrix `A` with shape
`[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
an `M x N` matrix. This matrix `A` may not be materialized, but for
purposes of identifying and working with compatible arguments the shape is
relevant.
Examples:
```python
some_tensor = ... shape = ????
operator = MyLinOp(some_tensor)
operator.shape
==> [2, 4, 4]
operator.log_abs_determinant()
==> Shape [2] Tensor
x = ... Shape [2, 4, 5] Tensor
operator.apply(x)
==> Shape [2, 4, 5] Tensor
```
### Shape compatibility
This operator acts on batch matrices with compatible shape.
FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE
### Performance
FILL THIS IN
"""
def __init__(self,
dtype,
graph_parents=None,
is_non_singular=None,
is_self_adjoint=None,
is_positive_definite=None,
name=None):
"""Initialize the `LinearOperator`.
**This is a private method for subclass use.**
**Subclasses should copy-paste this `__init__` documentation.**
For `X = non_singular, self_adjoint` etc...
`is_X` is a Python `bool` initialization argument with the following meaning
* If `is_X == True`, callers should expect the operator to have the
attribute `X`. This is a promise that should be fulfilled, but is *not* a
runtime assert. Issues, such as floating point error, could mean the
operator violates this promise.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
Args:
dtype: The type of this `LinearOperator`. Arguments to `apply` and
`solve` will have to be this type.
graph_parents: Python list of graph prerequisites of this `LinearOperator`,
typically tensors that are passed during initialization.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. If `dtype` is real, this is equivalent to being symmetric.
is_positive_definite: Expect that this operator is positive definite.
name: A name for this `LinearOperator`. Default: subclass name.
Raises:
ValueError: if any member of graph_parents is `None` or not a `Tensor`.
"""
if is_positive_definite and not is_self_adjoint:
raise ValueError(
"A positive definite matrix is by definition self adjoint")
if is_positive_definite and not is_non_singular:
raise ValueError(
"A positive definite matrix is by definition non-singular")
graph_parents = [] if graph_parents is None else graph_parents
for i, t in enumerate(graph_parents):
if t is None or not contrib_framework.is_tensor(t):
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
self._dtype = dtype
self._graph_parents = graph_parents
self._is_non_singular = is_non_singular
self._is_self_adjoint = is_self_adjoint
self._is_positive_definite = is_positive_definite
self._name = name or type(self).__name__
@contextlib.contextmanager
def _name_scope(self, name=None, values=None):
"""Helper function to standardize op scope."""
with ops.name_scope(self.name):
with ops.name_scope(
name, values=((values or []) + self._graph_parents)) as scope:
yield scope
@property
def dtype(self):
"""The `DType` of `Tensor`s handled by this `LinearOperator`."""
return self._dtype
@property
def name(self):
"""Name prepended to all ops created by this `LinearOperator`."""
return self._name
@property
def graph_parents(self):
"""List of graph dependencies of this `LinearOperator`."""
return self._graph_parents
@property
def is_non_singular(self):
return self._is_non_singular
@property
def is_self_adjoint(self):
return self._is_self_adjoint
@property
def is_positive_definite(self):
return self._is_positive_definite
def _shape(self):
# Write this in derived class to enable all static shape methods.
raise NotImplementedError("_shape is not implemented.")
@property
def shape(self):
"""`TensorShape` of this `LinearOperator`.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns
`TensorShape([B1,...,Bb, M, N])`, equivalent to `A.get_shape()`.
Returns:
`TensorShape`, statically determined, may be undefined.
"""
return self._shape()
def _shape_dynamic(self):
raise NotImplementedError("_shape_dynamic is not implemented.")
def shape_dynamic(self, name="shape_dynamic"):
"""Shape of this `LinearOperator`, determined at runtime.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
`[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.
Args:
name: A name for this `Op`.
Returns:
`int32` `Tensor`
"""
with self._name_scope(name):
return self._shape_dynamic()
@property
def batch_shape(self):
"""`TensorShape` of batch dimensions of this `LinearOperator`.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns
`TensorShape([B1,...,Bb])`, equivalent to `A.get_shape()[:-2]`.
Returns:
`TensorShape`, statically determined, may be undefined.
"""
# Derived classes get this "for free" once .shape is implemented.
return self.shape[:-2]
def batch_shape_dynamic(self, name="batch_shape_dynamic"):
"""Shape of batch dimensions of this operator, determined at runtime.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
`[B1,...,Bb]`.
Args:
name: A name for this `Op`.
Returns:
`int32` `Tensor`
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
return array_ops.slice(self.shape_dynamic(), [0],
[self.tensor_rank_dynamic() - 2])
@property
def tensor_rank(self):
"""Rank (in the sense of tensors) of matrix corresponding to this operator.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.
Returns:
Python integer, or None if the tensor rank is undefined.
"""
# Derived classes get this "for free" once .shape is implemented.
# Note: a @property getter cannot accept a name argument, and returning a
# static Python integer creates no ops, so no name scope is needed here.
return self.shape.ndims
def tensor_rank_dynamic(self, name="tensor_rank_dynamic"):
"""Rank (in the sense of tensors) of matrix corresponding to this operator.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.
Args:
name: A name for this `Op`.
Returns:
`int32` `Tensor`, determined at runtime.
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
return array_ops.size(self.shape_dynamic())
@property
def domain_dimension(self):
"""Dimension (in the sense of vector spaces) of the domain of this operator.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns `N`.
Returns:
Python integer if vector space dimension can be determined statically,
otherwise `None`.
"""
# Derived classes get this "for free" once .shape is implemented.
return self.shape[-1]
def domain_dimension_dynamic(self, name="domain_dimension_dynamic"):
"""Dimension (in the sense of vector spaces) of the domain of this operator.
Determined at runtime.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns `N`.
Args:
name: A name for this `Op`.
Returns:
`int32` `Tensor`
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
return array_ops.gather(self.shape_dynamic(),
self.tensor_rank_dynamic() - 1)
@property
def range_dimension(self):
"""Dimension (in the sense of vector spaces) of the range of this operator.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns `M`.
Returns:
Python integer if vector space dimension can be determined statically,
otherwise `None`.
"""
# Derived classes get this "for free" once .shape is implemented.
return self.shape[-2]
def range_dimension_dynamic(self, name="range_dimension_dynamic"):
"""Dimension (in the sense of vector spaces) of the range of this operator.
Determined at runtime.
If this operator acts like the batch matrix `A` with
`A.shape = [B1,...,Bb, M, N]`, then this returns `M`.
Args:
name: A name for this `Op`.
Returns:
`int32` `Tensor`
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
return array_ops.gather(self.shape_dynamic(),
self.tensor_rank_dynamic() - 2)
def _assert_non_singular(self):
raise NotImplementedError("assert_non_singular is not implemented.")
def assert_non_singular(self, name="assert_non_singular"):
"""Returns an `Op` that asserts this operator is non singular."""
with self._name_scope(name):
return self._assert_non_singular()
def _assert_positive_definite(self):
raise NotImplementedError("assert_positive_definite is not implemented.")
def assert_positive_definite(self, name="assert_positive_definite"):
"""Returns an `Op` that asserts this operator is positive definite."""
with self._name_scope(name):
return self._assert_positive_definite()
def _apply(self, x, adjoint=False):
raise NotImplementedError("_apply is not implemented.")
def apply(self, x, adjoint=False, name="apply"):
"""Transform `x` with left multiplication: `x --> Ax`.
Args:
x: `Tensor` with compatible shape and same `dtype` as `self`.
See class docstring for definition of compatibility.
adjoint: Python `bool`. If `True`, left multiply by the adjoint.
name: A name for this `Op`.
Returns:
A `Tensor` with shape `[..., M, R]` and same `dtype` as `self`.
"""
with self._name_scope(name, values=[x]):
x = ops.convert_to_tensor(x, name="x")
return self._apply(x, adjoint=adjoint)
def _determinant(self):
raise NotImplementedError("_det is not implemented.")
def determinant(self, name="det"):
"""Determinant for every batch member.
Args:
name: A name for this `Op`.
Returns:
`Tensor` with shape `self.batch_shape` and same `dtype` as `self`.
"""
with self._name_scope(name):
return self._determinant()
def _log_abs_determinant(self):
raise NotImplementedError("_log_abs_det is not implemented.")
def log_abs_determinant(self, name="log_abs_det"):
"""Log absolute value of determinant for every batch member.
Args:
name: A name for this `Op`.
Returns:
`Tensor` with shape `self.batch_shape` and same `dtype` as `self`.
"""
with self._name_scope(name):
return self._log_abs_determinant()
def _solve(self, rhs, adjoint=False):
# Since this is an exact solve method for all rhs, this will only be
# available for non-singular (batch) operators, in particular the operator
# must be square.
raise NotImplementedError("_solve is not implemented.")
def solve(self, rhs, adjoint=False, name="solve"):
"""Solve `R` (batch) systems of equations exactly: `A X = rhs`.
Examples:
```python
# Create an operator acting like a 10 x 2 x 2 matrix.
operator = LinearOperator(...)
operator.shape # = 10 x 2 x 2
# Solve one linear system (R = 1) for every member of the length 10 batch.
RHS = ... # shape 10 x 2 x 1
X = operator.solve(RHS) # shape 10 x 2 x 1
# Solve five linear systems (R = 5) for every member of the length 10 batch.
RHS = ... # shape 10 x 2 x 5
X = operator.solve(RHS)
X[3, :, 2] # Solution to the linear system A[3, :, :] X = RHS[3, :, 2]
```
Args:
rhs: `Tensor` with same `dtype` as this operator and compatible shape.
See class docstring for definition of compatibility.
adjoint: Python `bool`. If `True`, solve the system involving the adjoint
of this `LinearOperator`.
name: A name scope to use for ops added by this method.
Returns:
`Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.
Raises:
ValueError: If self.is_non_singular is False.
"""
if self.is_non_singular is False:
raise ValueError(
"Exact solve cannot be called with an operator that is expected to "
"be singular.")
with self._name_scope(name, values=[rhs]):
rhs = ops.convert_to_tensor(rhs, name="rhs")
return self._solve(rhs, adjoint=adjoint)
def _to_dense(self):
raise NotImplementedError("_to_dense is not implemented.")
def to_dense(self, name="to_dense"):
"""Return a dense (batch) matrix representing this operator."""
with self._name_scope(name):
return self._to_dense()
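To make the subclassing contract concrete, here is a minimal end-to-end
subclass built only from the hooks above. This is an illustrative sketch, not
part of this commit; the `LinearOperatorScaledIdentity` name and behavior are
hypothetical:

```python
import tensorflow as tf

from tensorflow.contrib.linalg.python.ops import linear_operator


class LinearOperatorScaledIdentity(linear_operator.LinearOperator):
  """Acts like `multiplier * I` with `num_rows` rows. Illustrative only."""

  def __init__(self, num_rows, multiplier,
               name="LinearOperatorScaledIdentity"):
    self._num_rows = num_rows  # Python integer.
    self._multiplier = tf.convert_to_tensor(multiplier, name="multiplier")
    super(LinearOperatorScaledIdentity, self).__init__(
        dtype=self._multiplier.dtype,
        graph_parents=[self._multiplier],
        is_self_adjoint=True,
        name=name)

  def _shape(self):
    # Enables shape, batch_shape, tensor_rank, domain/range_dimension.
    return tf.TensorShape([self._num_rows, self._num_rows])

  def _shape_dynamic(self):
    return tf.constant([self._num_rows, self._num_rows], dtype=tf.int32)

  def _apply(self, x, adjoint=False):
    # `adjoint` is a no-op for a real scaled identity.
    return self._multiplier * x

  def _determinant(self):
    return self._multiplier ** self._num_rows
```

With only `_shape`, `_shape_dynamic`, `_apply`, and `_determinant`
implemented, the inherited `shape`, `batch_shape`, `tensor_rank`, `apply`,
and `determinant` all work without materializing the matrix.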
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a diagonal matrix."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.linalg.python.ops import linear_operator
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
__all__ = ["LinearOperatorDiag",]
class LinearOperatorDiag(linear_operator.LinearOperator):
"""`LinearOperator` acting like a [batch] square diagonal matrix.
This operator acts like a [batch] matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
an `N x N` matrix. This matrix `A` may not be materialized, but for
purposes of broadcasting this shape will be relevant.
`LinearOperatorDiag` is initialized with a (batch) vector.
```python
# Create a 2 x 2 diagonal linear operator.
diag = [1., -1.]
operator = LinearOperatorDiag(diag)
operator.to_dense()
==> [[1., 0.]
[0., -1.]]
operator.shape
==> [2, 2]
operator.log_abs_determinant()
==> scalar Tensor
x = ... Shape [2, 4] Tensor
operator.apply(x)
==> Shape [2, 4] Tensor
# Create a [2, 3] batch of 4 x 4 linear operators.
diag = tf.random_normal(shape=[2, 3, 4])
operator = LinearOperatorDiag(diag)
# Create a shape [2, 1, 4, 2] vector. Note that this shape is compatible
# since the batch dimensions, [2, 1], are broadcast to
# operator.batch_shape = [2, 3].
y = tf.random_normal(shape=[2, 1, 4, 2])
x = operator.solve(y)
==> operator.apply(x) = y
```
### Shape compatibility
This operator acts on a [batch] matrix with compatible shape.
`x` is a batch matrix with compatible shape for `apply` and `solve` if
```
operator.shape = [B1,...,Bb] + [N, N], with b >= 0
x.shape = [C1,...,Cc] + [N, R],
and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
```
### Performance
Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
and `x.shape = [N, R]`. Then
* `operator.apply(x)` involves `N*R` multiplications.
* `operator.solve(x)` involves `N` divisions and `N*R` multiplications.
* `operator.determinant()` involves a size `N` `reduce_prod`.
If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
`[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
"""
def __init__(self,
diag,
is_non_singular=None,
is_self_adjoint=True,
is_positive_definite=None,
name="LinearOperatorDiag"):
"""Initialize a `LinearOperatorDiag`.
For `X = non_singular, self_adjoint` etc...
`is_X` is a Python `bool` initialization argument with the following meaning
* If `is_X == True`, callers should expect the operator to have the
attribute `X`. This is a promise that should be fulfilled, but is *not* a
runtime assert. Issues, such as floating point error, could mean the
operator violates this promise.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
Args:
diag: Shape `[B1,...,Bb, N]` real floating-point `Tensor` with `b >= 0`,
`N >= 0`. The diagonal of the operator.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. Since this is a real (not complex) diagonal operator, it is
always self adjoint.
is_positive_definite: Expect that this operator is positive definite.
name: A name for this `LinearOperator`. Default: subclass name.
Raises:
ValueError: If `diag.dtype` is not floating point.
ValueError: If `is_self_adjoint` is not `True`.
"""
with ops.name_scope(name, values=[diag]):
self._diag = ops.convert_to_tensor(diag, name="diag")
if not self._diag.dtype.is_floating:
raise ValueError("Only real floating point matrices are supported.")
if not is_self_adjoint:
raise ValueError("A real diagonal matrix is always self adjoint.")
super(LinearOperatorDiag, self).__init__(
dtype=self._diag.dtype,
graph_parents=[self._diag],
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
name=name)
def _shape(self):
# If d_shape = [5, 3], we return [5, 3, 3].
d_shape = self._diag.get_shape()
return d_shape.concatenate(d_shape[-1:])
def _shape_dynamic(self):
d_shape = array_ops.shape(self._diag)
k = d_shape[-1]
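# Append one more `k`: the shape vector [B1,...,Bb, k] becomes
# [B1,...,Bb, k, k].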
return array_ops.concat(0, (d_shape, [k]))
def _assert_non_singular(self):
nonzero_diag = math_ops.reduce_all(
math_ops.logical_not(math_ops.equal(self._diag, 0)))
return control_flow_ops.Assert(
nonzero_diag,
data=["Singular operator: diag contained zero values.", self._diag])
def _assert_positive_definite(self):
return check_ops.assert_positive(
self._diag,
message="Operator was not positive definite: diag was not all positive")
def _apply(self, x, adjoint=False):
# adjoint has no effect since this matrix is self-adjoint.
diag_mat = array_ops.expand_dims(self._diag, -1)
return diag_mat * x
def _determinant(self):
return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
def _log_abs_determinant(self):
return math_ops.reduce_sum(
math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
def _solve(self, rhs, adjoint=False):
# adjoint has no effect since this matrix is self-adjoint.
inv_diag_mat = array_ops.expand_dims(1. / self._diag, -1)
return rhs * inv_diag_mat
def _to_dense(self):
return array_ops.matrix_diag(self._diag)
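As a quick consistency check of the operator against its dense form (a sketch
mirroring what the test utilities below automate; not part of this commit):

```python
import numpy as np
import tensorflow as tf

linalg = tf.contrib.linalg

diag = tf.constant([2., 3., 4.])
operator = linalg.LinearOperatorDiag(diag)
dense = operator.to_dense()  # Same values as tf.matrix_diag(diag).

x = tf.ones(shape=[3, 2])
with tf.Session() as sess:
  op_apply, mat_apply = sess.run([operator.apply(x), tf.matmul(dense, x)])
  np.testing.assert_allclose(op_apply, mat_apply)
  print(sess.run(operator.determinant()))  # 2 * 3 * 4 = 24.
```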
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing `LinearOperator` and sub-classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
import tensorflow as tf
@six.add_metaclass(abc.ABCMeta) # pylint: disable=no-init
class LinearOperatorDerivedClassTest(tf.test.TestCase):
"""Tests for derived classes.
Subclasses should implement every abstractmethod, and this will enable all
test methods to work.
"""
@abc.abstractproperty
def _dtypes_to_test(self):
"""Returns list of numpy or tensorflow dtypes. Each will be tested."""
raise NotImplementedError("dtypes_to_test has not been implemented.")
@abc.abstractproperty
def _shapes_to_test(self):
"""Returns list of tuples, each is one shape that will be tested."""
raise NotImplementedError("shapes_to_test has not been implemented.")
@abc.abstractmethod
def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
"""Build a batch matrix and an Operator that should have similar behavior.
Every operator acts like a (batch) matrix. This method returns both
together, and is used by tests.
Args:
shape: List-like of Python integers giving full shape of operator.
dtype: Numpy dtype. Data type of returned array/operator.
use_placeholder: Python bool. If True, initialize the operator with a
placeholder of undefined shape and correct dtype.
Returns:
operator: `LinearOperator` subclass instance.
mat: `Tensor` representing operator.
feed_dict: Dictionary. If use_placeholder is True, this must be fed to
sess.run calls at runtime to make the operator work.
"""
# Create a matrix as a numpy array with desired shape/dtype.
# Create a LinearOperator that should have the same behavior as the matrix.
raise NotImplementedError("Not implemented yet.")
@abc.abstractmethod
def _make_rhs(self, operator):
"""Make a rhs appropriate for calling operator.solve(rhs)."""
raise NotImplementedError("_make_rhs is not defined.")
@abc.abstractmethod
def _make_x(self, operator):
"""Make a rhs appropriate for calling operator.apply(rhs)."""
raise NotImplementedError("_make_x is not defined.")
def _maybe_adjoint(self, x, adjoint):
if adjoint:
return tf.matrix_transpose(x)
else:
return x
def test_to_dense(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, _ = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=False)
op_dense = operator.to_dense()
self.assertAllEqual(shape, op_dense.get_shape())
op_dense_v, mat_v = sess.run([op_dense, mat])
self.assertAllClose(op_dense_v, mat_v)
def test_to_dense_dynamic(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=True)
op_dense_v, mat_v = sess.run(
[operator.to_dense(), mat], feed_dict=feed_dict)
self.assertAllClose(op_dense_v, mat_v)
def test_det(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, _ = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=False)
op_det = operator.determinant()
self.assertAllEqual(shape[:-2], op_det.get_shape())
op_det_v, mat_det_v = sess.run([op_det, tf.matrix_determinant(mat)])
self.assertAllClose(op_det_v, mat_det_v)
def test_det_dynamic(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=True)
op_det_v, mat_det_v = sess.run(
[operator.determinant(), tf.matrix_determinant(mat)],
feed_dict=feed_dict)
self.assertAllClose(op_det_v, mat_det_v)
def test_apply(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, _ = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=False)
for adjoint in [False, True]:
if adjoint and operator.is_self_adjoint:
continue
x = self._make_x(operator)
op_apply = operator.apply(x, adjoint=adjoint)
mat_apply = tf.batch_matmul(self._maybe_adjoint(mat, adjoint), x)
self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape())
op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply])
self.assertAllClose(op_apply_v, mat_apply_v)
def test_apply_dynamic(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=True)
x = self._make_x(operator)
op_apply_v, mat_apply_v = sess.run(
[operator.apply(x), tf.batch_matmul(mat, x)],
feed_dict=feed_dict)
self.assertAllClose(op_apply_v, mat_apply_v)
def test_solve(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, _ = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=False)
for adjoint in [False, True]:
if adjoint and operator.is_self_adjoint:
continue
rhs = self._make_rhs(operator)
op_solve = operator.solve(rhs, adjoint=adjoint)
mat_solve = tf.matrix_solve(self._maybe_adjoint(mat, adjoint), rhs)
self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
self.assertAllClose(op_solve_v, mat_solve_v)
def test_solve_dynamic(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=True)
rhs = self._make_rhs(operator)
op_solve_v, mat_solve_v = sess.run(
[operator.solve(rhs), tf.matrix_solve(mat, rhs)],
feed_dict=feed_dict)
self.assertAllClose(op_solve_v, mat_solve_v)
@@ -71,6 +71,7 @@ def module_names():
"tf.contrib.layers",
"tf.contrib.learn",
"tf.contrib.learn.monitors",
"tf.contrib.linalg",
"tf.contrib.losses",
"tf.contrib.metrics",
"tf.contrib.rnn",
@@ -213,7 +214,7 @@ def all_libraries(module_to_name, members, documented):
"BayesFlow Variational Inference (contrib)",
tf.contrib.bayesflow.variational_inference),
library("contrib.crf", "CRF (contrib)", tf.contrib.crf),
library("contrib.distributions", "Statistical distributions (contrib)",
library("contrib.distributions", "Statistical Distributions (contrib)",
tf.contrib.distributions),
library("contrib.distributions.bijector",
"Random variable transformations (contrib)",
@@ -227,6 +228,8 @@ def all_libraries(module_to_name, members, documented):
library("contrib.learn", "Learn (contrib)", tf.contrib.learn),
library("contrib.learn.monitors", "Monitors (contrib)",
tf.contrib.learn.monitors),
library("contrib.linalg", "Linear Algebra (contrib)",
tf.contrib.linalg),
library("contrib.losses", "Losses (contrib)", tf.contrib.losses),
library("contrib.rnn", "RNN (contrib)", tf.contrib.rnn),
library("contrib.metrics", "Metrics (contrib)", tf.contrib.metrics),
......