Commit 715f951e authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Adding a TF QR op.

Change: 139959769
Parent dc4b868b
@@ -151,6 +151,10 @@ if (tensorflow_BUILD_PYTHON_TESTS)
      # missing kernel
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/conv_ops_test.py"
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/depthwise_conv_op_test.py"
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py"
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/pool_test.py"
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/qr_op_test.py"
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py"
      # cuda launch failed
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py"
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/trace_op_test.py"
......
@@ -1324,6 +1324,7 @@ tf_kernel_libraries(
        "matrix_solve_ls_op",
        "matrix_solve_op",
        "matrix_triangular_solve_op",
        "qr_op",
        "svd_op",
    ],
    deps = [
......
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
#include <algorithm>
#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {

template <class Scalar>
class QrOp : public LinearAlgebraOp<Scalar> {
 public:
  typedef LinearAlgebraOp<Scalar> Base;

  explicit QrOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
  }

  using TensorShapes = typename Base::TensorShapes;

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSingleMatrix(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    int64 m = input_matrix_shapes[0].dim_size(0);
    int64 n = input_matrix_shapes[0].dim_size(1);
    int64 min_size = std::min(m, n);
    if (full_matrices_) {
      return TensorShapes({TensorShape({m, m}), TensorShape({m, n})});
    } else {
      return TensorShapes(
          {TensorShape({m, min_size}), TensorShape({min_size, n})});
    }
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
    double max_size = std::max(m, n);
    double min_size = std::min(m, n);
    double cost = 2 * max_size * min_size * min_size -
                  2 * min_size * min_size * min_size / 3.;
    // TODO(jpoulson): Increase the cost if full_matrices is true in a manner
    // that reflects the algorithm used for the expansion.
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  using Matrix = typename Base::Matrix;
  using MatrixMaps = typename Base::MatrixMaps;
  using ConstMatrixMap = typename Base::ConstMatrixMap;
  using ConstMatrixMaps = typename Base::ConstMatrixMaps;

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Eigen::HouseholderQR<Matrix> qr(inputs[0]);
    const int m = inputs[0].rows();
    const int n = inputs[0].cols();
    const int min_size = std::min(m, n);

    if (full_matrices_) {
      outputs->at(0) = qr.householderQ();
      outputs->at(1) = qr.matrixQR().template triangularView<Eigen::Upper>();
    } else {
      // TODO(jpoulson): Exploit the fact that Householder transformations can
      // be expanded faster than they can be applied to an arbitrary matrix
      // (Cf. LAPACK's DORGQR).
      Matrix tmp = Matrix::Identity(m, min_size);
      outputs->at(0) = qr.householderQ() * tmp;
      auto qr_top = qr.matrixQR().block(0, 0, min_size, n);
      outputs->at(1) = qr_top.template triangularView<Eigen::Upper>();
    }
  }

 private:
  bool full_matrices_;

  TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};

}  // namespace tensorflow
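For orientation, the thin-versus-full branch in ComputeMatrix maps directly onto NumPy's two QR modes. A minimal sketch of the same semantics (an editorial illustration, not part of the commit; `qr_like_kernel` is a made-up name):

```python
import numpy as np

def qr_like_kernel(a, full_matrices=False):
  # Mirrors ComputeMatrix: "complete" expands the full m x m Q,
  # "reduced" keeps only the leading min(m, n) columns.
  if full_matrices:
    return np.linalg.qr(a, mode="complete")  # q: [m, m], r: [m, n]
  return np.linalg.qr(a, mode="reduced")     # q: [m, p], r: [p, n], p = min(m, n)

a = np.random.randn(5, 3)
q, r = qr_like_kernel(a)
assert np.allclose(q.dot(r), a)            # a ~= q * r
assert np.allclose(q.T.dot(q), np.eye(3))  # columns of q are orthonormal
```

This also shows why the thin form is cheaper: only `min(m, n)` Householder columns are ever expanded, which is what the DORGQR TODO in ComputeMatrix is about.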
@@ -109,6 +109,36 @@ Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
  return Status::OK();
}
// Input is [...,M,N].
// First and second outputs are:
// [...,M,M]; [...,M,N], if full_matrices is true,
// [...,M,P]; [...,P,N], if full_matrices is false,
// where P = min(M,N).
Status QrShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
  DimensionHandle m = c->Dim(input, -2);
  DimensionHandle n = c->Dim(input, -1);
  DimensionHandle p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle q_shape;
  ShapeHandle r_shape;
  bool full_matrices;
  TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
  if (full_matrices) {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
  } else {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
  }
  c->set_output(0, q_shape);
  c->set_output(1, r_shape);
  return Status::OK();
}
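Spelled out in Python, the shape rule QrShapeFn implements is just the following (a hypothetical helper for illustration; `qr_output_shapes` does not exist in the codebase):

```python
def qr_output_shapes(shape, full_matrices=False):
  # shape is [..., M, N]; returns the q and r shapes chosen by QrShapeFn.
  batch = list(shape[:-2])
  m, n = shape[-2], shape[-1]
  p = min(m, n)
  if full_matrices:
    return tuple(batch + [m, m]), tuple(batch + [m, n])
  return tuple(batch + [m, p]), tuple(batch + [p, n])

assert qr_output_shapes([3, 4, 2]) == ((3, 4, 2), (3, 2, 2))
assert qr_output_shapes([3, 4, 2], full_matrices=True) == ((3, 4, 4), (3, 4, 2))
```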
// Input is [...,M,N]. First output is [...,min(M,N)].
// Second and third outputs are:
// [0]; [0], if compute_uv is false.
@@ -435,6 +465,38 @@ Equivalent to np.linalg.lstsq
@end_compatibility
)doc");
REGISTER_OP("Qr")
.Input("input: T")
.Output("q: T")
.Output("r: T")
.Attr("full_matrices: bool = False")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(QrShapeFn)
.Doc(R"doc(
Computes the QR decompositions of one or more matrices.
Computes the QR decomposition of each inner matrix in `tensor` such that
`tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
```prettyprint
# a is a tensor.
# q is a tensor of orthonormal matrices.
# r is a tensor of upper triangular matrices.
q, r = qr(a)
q_full, r_full = qr(a, full_matrices=True)
```
input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
q: Orthonormal basis for range of `a`. If `full_matrices` is `False` then
shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
`[..., M, M]`.
r: Triangular factor. If `full_matrices` is `False` then shape is
`[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
full_matrices: If true, compute full-sized `q` and `r`. If false
(the default), compute only the leading `P` columns of `q`.
)doc");
REGISTER_OP("Svd")
.Input("input: T")
.Output("s: T")
@@ -463,10 +525,10 @@ input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
  form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
s: Singular values. Shape is `[..., P]`.
u: Left singular vectors. If `full_matrices` is `False` then shape is
  `[..., M, M]`; if `full_matrices` is `True` then shape is
  `[..., M, P]`. Undefined if `compute_uv` is `False`.
  `[..., M, P]`; if `full_matrices` is `True` then shape is
  `[..., M, M]`. Undefined if `compute_uv` is `False`.
v: Right singular vectors. If `full_matrices` is `False` then shape is
  `[..., N, N]`. If `full_matrices` is `True` then shape is `[..., N, P]`.
  `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`.
  Undefined if `compute_uv` is false.
compute_uv: If true, left and right singular vectors will be
  computed and returned in `u` and `v`, respectively.
......
@@ -171,6 +171,50 @@ TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;?;[1]");
}
TEST(LinalgOpsTest, Qr_ShapeFn) {
  ShapeInferenceTestOp op("Qr");
  auto set_attrs = [&op](bool full_matrices) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Qr")
                     .Input({"input", 0, DT_FLOAT})
                     .Attr("full_matrices", full_matrices)
                     .Finalize(&op.node_def));
  };

  // With `P` = min(`M`, `N`): if full_matrices is false, Q should be
  // `M` x `P` and `R` should be `P` x `N`; otherwise Q should be `M` x `M`
  // and `R` should be `M` x `N`.
  //
  // For rank-3 tensors, `M` = d0_1 and `N` = d0_2.
  set_attrs(false);
  INFER_OK(op, "?", "?;?");
  INFER_OK(op, "[?,?,?]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
  INFER_OK(op, "[4,?,?]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
  INFER_OK(op, "[4,2,?]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
  INFER_OK(op, "[4,?,2]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
  INFER_OK(op, "[?,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[?,3,2]", "[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]");
  INFER_OK(op, "[4,3,2]", "[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]");
  INFER_OK(op, "[?,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");

  set_attrs(true);
  INFER_OK(op, "?", "?;?");
  INFER_OK(op, "[?,?,?]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,?,?]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,2,?]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,?,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[?,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[?,3,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,3,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[?,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_OK(op, "[4,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
}
TEST(LinalgOpsTest, Svd_ShapeFn) {
  ShapeInferenceTestOp op("Svd");
  auto set_attrs = [&op](bool compute_uv, bool full_matrices) {
@@ -180,6 +224,13 @@ TEST(LinalgOpsTest, Svd_ShapeFn) {
                     .Attr("full_matrices", full_matrices)
                     .Finalize(&op.node_def));
  };

  // With `P` = min(`M`, `N`): if full_matrices is false, U should be
  // `M` x `P` and `V` should be `N` x `P`; otherwise U should be `M` x `M`
  // and `V` should be `N` x `N`.
  //
  // For rank-3 tensors, `M` = d0_1 and `N` = d0_2.
  set_attrs(false, false);
  INFER_OK(op, "?", "?;[0];[0]");
  INFER_OK(op, "[?,?,?]", "[d0_0,?];[0];[0]");
......
@@ -1030,8 +1030,10 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        ":array_ops",
        ":control_flow_ops_gen",
        ":data_flow_ops_gen",
        ":framework",
        ":framework_for_generated_wrappers",
        ":math_ops_gen",
        ":sparse_ops_gen",
        ":state_ops",
......
@@ -1355,6 +1355,15 @@ cuda_py_test(
    shard_count = 20,
)

cuda_py_test(
    name = "qr_op_test",
    size = "medium",
    srcs = ["qr_op_test.py"],
    additional_deps = ["//tensorflow:tensorflow_py"],
    shard_count = 20,
    tags = ["nomsan"],  # fails in msan from numpy calls
)

cuda_py_test(
    name = "svd_op_test",
    size = "medium",
......
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class QrOpTest(tf.test.TestCase):

  def testWrongDimensions(self):
    # The input to qr should be a tensor of at least rank 2.
    scalar = tf.constant(1.)
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be at least rank 2 but is rank 0"):
      tf.qr(scalar)
    vector = tf.constant([1., 2.])
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be at least rank 2 but is rank 1"):
      tf.qr(vector)


def _GetQrOpTest(dtype_, shape_, use_static_shape_):
  is_complex = dtype_ in (np.complex64, np.complex128)
  is_single = dtype_ in (np.float32, np.complex64)

  def CompareOrthogonal(self, x, y, rank):
    if is_single:
      atol = 5e-4
    else:
      atol = 5e-14
    # We only compare the first 'rank' orthogonal vectors since the
    # remainder form an arbitrary orthonormal basis for the
    # (row- or column-) null space, whose exact value depends on
    # implementation details. Notice that since we check that the
    # matrices of singular vectors are unitary elsewhere, we do
    # implicitly test that the trailing vectors of x and y span the
    # same space.
    x = x[..., 0:rank]
    y = y[..., 0:rank]
    # Q is only unique up to sign (complex phase factor for complex matrices),
    # so we normalize the sign first.
    sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
    phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
    x *= phases
    self.assertAllClose(x, y, atol=atol)

  def CheckApproximation(self, a, q, r):
    if is_single:
      tol = 1e-5
    else:
      tol = 1e-14
    # Tests that a ~= q*r.
    a_recon = tf.matmul(q, r)
    self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol)

  def CheckUnitary(self, x):
    # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
    xx = tf.matmul(tf.conj(x), x, transpose_a=True)
    identity = tf.matrix_band_part(tf.ones_like(xx), 0, 0)
    if is_single:
      tol = 1e-5
    else:
      tol = 1e-14
    self.assertAllClose(identity.eval(), xx.eval(), atol=tol)

  def Test(self):
    np.random.seed(1)
    x_np = np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    if is_complex:
      x_np += 1j * np.random.uniform(
          low=-1.0, high=1.0,
          size=np.prod(shape_)).reshape(shape_).astype(dtype_)

    for full_matrices in False, True:
      with self.test_session() as sess:
        if use_static_shape_:
          x_tf = tf.constant(x_np)
        else:
          x_tf = tf.placeholder(dtype_)
        q_tf, r_tf = tf.qr(x_tf, full_matrices=full_matrices)
        if use_static_shape_:
          q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
        else:
          q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})

        q_dims = q_tf_val.shape
        np_q = np.ndarray(q_dims, dtype_)
        np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
        new_first_dim = np_q_reshape.shape[0]

        x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
        for i in range(new_first_dim):
          if full_matrices:
            np_q_reshape[i, :, :], _ = \
                np.linalg.qr(x_reshape[i, :, :], mode="complete")
          else:
            np_q_reshape[i, :, :], _ = \
                np.linalg.qr(x_reshape[i, :, :], mode="reduced")
        np_q = np.reshape(np_q_reshape, q_dims)

        CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
        CheckApproximation(self, x_np, q_tf_val, r_tf_val)
        CheckUnitary(self, q_tf_val)

  return Test


if __name__ == "__main__":
  for dtype in np.float32, np.float64, np.complex64, np.complex128:
    for rows in 1, 2, 5, 10, 32, 100:
      for cols in 1, 2, 5, 10, 32, 100:
        for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
          shape = batch_dims + (rows, cols)
          for use_static_shape in True, False:
            name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)),
                                 use_static_shape)
            setattr(QrOpTest, "testQr_" + name,
                    _GetQrOpTest(dtype, shape, use_static_shape))
  tf.test.main()
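The sign/phase normalization used by CompareOrthogonal above is worth seeing in isolation; a standalone sketch (editorial, with made-up data):

```python
import numpy as np

# Two valid thin-QR Q factors of the same matrix may differ column-wise by a
# sign (or, for complex inputs, by a unit-modulus phase).
q = np.linalg.qr(np.random.randn(4, 3))[0]
q_flipped = q * np.array([1.0, -1.0, 1.0])  # flip the sign of one column

# Estimate each column's ratio and strip everything but its phase, exactly
# as the test does before calling assertAllClose.
sum_of_ratios = np.sum(q_flipped / q, axis=-2, keepdims=True)
phases = sum_of_ratios / np.abs(sum_of_ratios)
assert np.allclose(q * phases, q_flipped)
```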
@@ -24,7 +24,7 @@ import tensorflow as tf


class SvdOpTest(tf.test.TestCase):

  def testWrongDimensions(self):
    # The input to batch_svd should be a tensor of at least rank 2.
    # The input to svd should be a tensor of at least rank 2.
    scalar = tf.constant(1.)
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be at least rank 2 but is rank 0"):
@@ -35,7 +35,7 @@ class SvdOpTest(tf.test.TestCase):
      tf.svd(vector)


def _GetSvdOpTest(dtype_, shape_):
def _GetSvdOpTest(dtype_, shape_, use_static_shape_):
  is_complex = dtype_ in (np.complex64, np.complex128)
  is_single = dtype_ in (np.float32, np.complex64)
@@ -101,40 +101,61 @@ def _GetSvdOpTest(dtype_, shape_):
  def Test(self):
    np.random.seed(1)
    x = np.random.uniform(
    x_np = np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    if is_complex:
      x += 1j * np.random.uniform(
      x_np += 1j * np.random.uniform(
          low=-1.0, high=1.0,
          size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    for compute_uv in False, True:
      for full_matrices in False, True:
        with self.test_session():
        with self.test_session() as sess:
          if use_static_shape_:
            x_tf = tf.constant(x_np)
          else:
            x_tf = tf.placeholder(dtype_)
          if compute_uv:
            tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
            s_tf, u_tf, v_tf = tf.svd(x_tf,
                                      compute_uv=compute_uv,
                                      full_matrices=full_matrices)
            if use_static_shape_:
              s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf])
            else:
              s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf],
                                                      feed_dict={x_tf: x_np})
          else:
            tf_s = tf.svd(tf.constant(x),
            s_tf = tf.svd(x_tf,
                          compute_uv=compute_uv,
                          full_matrices=full_matrices)
            if use_static_shape_:
              s_tf_val = sess.run(s_tf)
            else:
              s_tf_val = sess.run(s_tf, feed_dict={x_tf: x_np})
          if compute_uv:
            np_u, np_s, np_v = np.linalg.svd(x,
            u_np, s_np, v_np = np.linalg.svd(x_np,
                                             compute_uv=compute_uv,
                                             full_matrices=full_matrices)
          else:
            np_s = np.linalg.svd(x,
            s_np = np.linalg.svd(x_np,
                                 compute_uv=compute_uv,
                                 full_matrices=full_matrices)
          CompareSingularValues(self, np_s, tf_s.eval())
          # We explicitly avoid the situation where numpy eliminates a first
          # dimension that is equal to one.
          s_np = np.reshape(s_np, s_tf_val.shape)
          CompareSingularValues(self, s_np, s_tf_val)
          if compute_uv:
            CompareSingularVectors(self, np_u, tf_u.eval(), min(shape_[-2:]))
            CompareSingularVectors(self, np.conj(np.swapaxes(np_v, -2, -1)),
                                   tf_v.eval(), min(shape_[-2:]))
            CheckApproximation(self, x, tf_u, tf_s, tf_v, full_matrices)
            CheckUnitary(self, tf_u)
            CheckUnitary(self, tf_v)
            CompareSingularVectors(self, u_np, u_tf_val, min(shape_[-2:]))
            CompareSingularVectors(self,
                                   np.conj(np.swapaxes(v_np, -2, -1)), v_tf_val,
                                   min(shape_[-2:]))
            CheckApproximation(self, x_np, u_tf_val, s_tf_val, v_tf_val,
                               full_matrices)
            CheckUnitary(self, u_tf_val)
            CheckUnitary(self, v_tf_val)

  return Test
@@ -145,6 +166,9 @@ if __name__ == "__main__":
      for cols in 1, 2, 5, 10, 32, 100:
        for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
          shape = batch_dims + (rows, cols)
          name = "%s_%s" % (dtype.__name__, "_".join(map(str, shape)))
          setattr(SvdOpTest, "testSvd_" + name, _GetSvdOpTest(dtype, shape))
          for use_static_shape in True, False:
            name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)),
                                 use_static_shape)
            setattr(SvdOpTest, "testSvd_" + name,
                    _GetSvdOpTest(dtype, shape, use_static_shape))
  tf.test.main()
@@ -104,6 +104,7 @@ functions on matrices to your graph.
@@matrix_solve
@@matrix_triangular_solve
@@matrix_solve_ls
@@qr
@@self_adjoint_eig
@@self_adjoint_eigvals
@@svd
@@ -949,6 +950,7 @@ def div(x, y, name=None):
def div_deprecated(x, y, name=None):
  return gen_math_ops.div(x, y, name)

mod = gen_math_ops.floor_mod
@@ -1001,6 +1003,7 @@ def floordiv_deprecated(x, y, name=None):
  # return gen_math_ops.floor_div(x, y, name=name)
  return gen_math_ops.div(x, y, name=name)

realdiv = gen_math_ops.real_div
truncatediv = gen_math_ops.truncate_div

# TODO(aselle): Rename this to floordiv when we can.
......