提交 715f951e 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Adding a TF QR op.

Change: 139959769
上级 dc4b868b
......@@ -4,16 +4,16 @@ enable_testing()
# get a temp path for test data
#
function(GetTestRunPath VAR_NAME OBJ_NAME)
if(WIN32)
if(DEFINED ENV{TMP})
if(WIN32)
if(DEFINED ENV{TMP})
set(TMPDIR "$ENV{TMP}")
elseif(DEFINED ENV{TEMP})
set(TMPDIR "$ENV{TEMP}")
endif()
string(REPLACE "\\" "/" TMPDIR ${TMPDIR})
else()
set(TMPDIR "$ENV{TMPDIR}")
endif()
else()
set(TMPDIR "$ENV{TMPDIR}")
endif()
if(NOT EXISTS "${TMPDIR}")
message(FATAL_ERROR "Unable to determine a path to the temporary directory")
endif()
......@@ -45,7 +45,7 @@ endfunction(AddTests)
#
function(AddTest)
cmake_parse_arguments(_AT "" "TARGET" "SOURCES;OBJECTS;LIBS;DATA;DEPENDS" ${ARGN})
list(REMOVE_DUPLICATES _AT_SOURCES)
list(REMOVE_DUPLICATES _AT_OBJECTS)
list(REMOVE_DUPLICATES _AT_LIBS)
......@@ -55,7 +55,7 @@ function(AddTest)
if (_AT_DEPENDS)
list(REMOVE_DUPLICATES _AT_DEPENDS)
endif(_AT_DEPENDS)
add_executable(${_AT_TARGET} ${_AT_SOURCES} ${_AT_OBJECTS})
target_link_libraries(${_AT_TARGET} ${_AT_LIBS})
......@@ -96,7 +96,7 @@ function(AddPythonTests)
if (_AT_DEPENDS)
list(REMOVE_DUPLICATES _AT_DEPENDS)
endif(_AT_DEPENDS)
foreach(sourcefile ${_AT_SOURCES})
add_test(NAME ${sourcefile} COMMAND ${PYTHON_EXECUTABLE} ${sourcefile})
if (_AT_DEPENDS)
......@@ -108,11 +108,11 @@ endfunction(AddPythonTests)
if (tensorflow_BUILD_PYTHON_TESTS)
#
# python tests. This assumes that the tensorflow wheel is
# installed on the test system.
# installed on the test system.
# TODO: we currently don't handle tests that need to have
# some environment setup: see AddTest how to add this
#
# include all test
file(GLOB_RECURSE tf_test_src_py
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py"
......@@ -124,14 +124,14 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/resource_variable_ops_test.py"
)
)
if (WIN32)
set(tf_test_src_py_exclude
${tf_test_src_py_exclude}
# generally excluded
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py"
# TODO: failing tests.
# TODO: failing tests.
# Nothing critical in here but should get this list down to []
# The failing list is grouped by failure source
# stl on windows handles overflows different
......@@ -148,9 +148,13 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py"
# issues related to windows fs
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/io_ops_test.py"
# missing kernel
# missing kernel
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/conv_ops_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/depthwise_conv_op_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/pool_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/qr_op_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py"
# cuda launch failed
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/trace_op_test.py"
......@@ -158,10 +162,10 @@ if (tensorflow_BUILD_PYTHON_TESTS)
)
endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
AddPythonTests(
SOURCES ${tf_test_src_py}
)
)
endif(tensorflow_BUILD_PYTHON_TESTS)
if (tensorflow_BUILD_CC_TESTS)
......@@ -169,9 +173,9 @@ if (tensorflow_BUILD_CC_TESTS)
# cc unit tests. Be aware that by default we include 250+ tests which
# will take time and space to build.
# If you wan to cut this down, for example to a specific test, modify
# tf_test_src_simple to your needs
# tf_test_src_simple to your needs
#
include_directories(${googletest_INCLUDE_DIRS})
# cc tests wrapper
......@@ -228,7 +232,7 @@ if (tensorflow_BUILD_CC_TESTS)
# generally excluded
"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_ops_test.cc" # test_op.h missing
# TODO: test failing
"${tensorflow_source_dir}/tensorflow/core/common_runtime/simple_placer_test.cc"
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/executor_test.cc"
......@@ -254,7 +258,7 @@ if (tensorflow_BUILD_CC_TESTS)
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops_test.cc" # status 5
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops_test.cc" # status 5
# TODO: not compiling
# TODO: not compiling
"${tensorflow_source_dir}/tensorflow/cc/framework/gradient_checker_test.cc"
"${tensorflow_source_dir}/tensorflow/cc/gradients/math_grad_test.cc"
"${tensorflow_source_dir}/tensorflow/cc/gradients/array_grad_test.cc"
......@@ -344,13 +348,13 @@ if (tensorflow_BUILD_CC_TESTS)
endif()
list(REMOVE_ITEM tf_test_src_simple ${tf_test_src_simple_exclude})
set(tf_test_lib tf_test_lib)
add_library(${tf_test_lib} STATIC ${tf_src_testlib})
# this is giving to much objects and libraries to the linker but
# this is giving to much objects and libraries to the linker but
# it makes this script much easier. So for now we do it this way.
set(tf_obj_test
set(tf_obj_test
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
......@@ -362,10 +366,10 @@ if (tensorflow_BUILD_CC_TESTS)
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
set(tf_test_libs
set(tf_test_libs
tf_protos_cc
tf_test_lib
${tf_core_gpu_kernels_lib}
${tf_core_gpu_kernels_lib}
${googletest_STATIC_LIBRARIES}
${tensorflow_EXTERNAL_LIBRARIES}
)
......@@ -373,7 +377,7 @@ if (tensorflow_BUILD_CC_TESTS)
AddTests(
SOURCES ${tf_test_src_simple}
OBJECTS ${tf_obj_test}
LIBS ${tf_test_libs}
LIBS ${tf_test_libs}
DEPENDS googletest
)
endif(tensorflow_BUILD_CC_TESTS)
......@@ -1324,6 +1324,7 @@ tf_kernel_libraries(
"matrix_solve_ls_op",
"matrix_solve_op",
"matrix_triangular_solve_op",
"qr_op",
"svd_op",
],
deps = [
......
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
#include <algorithm>
#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
template <class Scalar>
class QrOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar> Base;
explicit QrOp(OpKernelConstruction* context) : Base(context) {
OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
}
using TensorShapes = typename Base::TensorShapes;
void ValidateInputMatrixShapes(
OpKernelContext* context,
const TensorShapes& input_matrix_shapes) const final {
Base::ValidateSingleMatrix(context, input_matrix_shapes);
}
TensorShapes GetOutputMatrixShapes(
const TensorShapes& input_matrix_shapes) const final {
int64 m = input_matrix_shapes[0].dim_size(0);
int64 n = input_matrix_shapes[0].dim_size(1);
int64 min_size = std::min(m, n);
if (full_matrices_) {
return TensorShapes({TensorShape({m, m}), TensorShape({m, n})});
} else {
return TensorShapes(
{TensorShape({m, min_size}), TensorShape({min_size, n})});
}
}
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
double max_size = std::max(m, n);
double min_size = std::min(m, n);
double cost = 2 * max_size * min_size * min_size -
2 * min_size * min_size * min_size / 3.;
// TODO(jpoulson): Increase the cost if full_matrices is true in a manner
// that reflects the algorithm used for the expansion.
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
using Matrix = typename Base::Matrix;
using MatrixMaps = typename Base::MatrixMaps;
using ConstMatrixMap = typename Base::ConstMatrixMap;
using ConstMatrixMaps = typename Base::ConstMatrixMaps;
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
MatrixMaps* outputs) final {
Eigen::HouseholderQR<Matrix> qr(inputs[0]);
const int m = inputs[0].rows();
const int n = inputs[0].cols();
const int min_size = std::min(m, n);
if (full_matrices_) {
outputs->at(0) = qr.householderQ();
outputs->at(1) = qr.matrixQR().template triangularView<Eigen::Upper>();
} else {
// TODO(jpoulson): Exploit the fact that Householder transformations can
// be expanded faster than they can be applied to an arbitrary matrix
// (Cf. LAPACK's DORGQR).
Matrix tmp = Matrix::Identity(m, min_size);
outputs->at(0) = qr.householderQ() * tmp;
auto qr_top = qr.matrixQR().block(0, 0, min_size, n);
outputs->at(1) = qr_top.template triangularView<Eigen::Upper>();
}
}
private:
bool full_matrices_;
TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};
} // namespace tensorflow
......@@ -109,6 +109,36 @@ Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
return Status::OK();
}
// Input is [...,M,N].
// First and second outputs are:
// [...,M,M]; [...,M,N], if full_matrices is true,
// [...,M,P]; [...,P,N], if full_matrices is false,
// where P = min(M,N).
Status QrShapeFn(InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
DimensionHandle m = c->Dim(input, -2);
DimensionHandle n = c->Dim(input, -1);
DimensionHandle p;
TF_RETURN_IF_ERROR(c->Min(m, n, &p));
ShapeHandle batch_shape;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
ShapeHandle q_shape;
ShapeHandle r_shape;
bool full_matrices;
TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
if (full_matrices) {
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
} else {
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
}
c->set_output(0, q_shape);
c->set_output(1, r_shape);
return Status::OK();
}
// Input is [...,M,N]. First output is [...,min(M,N)].
// Second and third outputs are:
// [0]; [0], if compute_uv is false.
......@@ -435,6 +465,38 @@ Equivalent to np.linalg.lstsq
@end_compatibility
)doc");
REGISTER_OP("Qr")
.Input("input: T")
.Output("q: T")
.Output("r: T")
.Attr("full_matrices: bool = False")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(QrShapeFn)
.Doc(R"doc(
Computes the QR decompositions of one or more matrices.
Computes the QR decomposition of each inner matrix in `tensor` such that
`tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
```prettyprint
# a is a tensor.
# q is a tensor of orthonormal matrices.
# r is a tensor of upper triangular matrices.
q, r = qr(a)
q_full, r_full = qr(a, full_matrices=True)
```
input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
q: Orthonormal basis for range of `a`. If `full_matrices` is `False` then
shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
`[..., M, M]`.
r: Triangular factor. If `full_matrices` is `False` then shape is
`[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
full_matrices: If true, compute full-sized `q` and `r`. If false
(the default), compute only the leading `P` columns of `q`.
)doc");
REGISTER_OP("Svd")
.Input("input: T")
.Output("s: T")
......@@ -463,10 +525,10 @@ input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
s: Singular values. Shape is `[..., P]`.
u: Left singular vectors. If `full_matrices` is `False` then shape is
`[..., M, M]`; if `full_matrices` is `True` then shape is
`[..., M, P]`. Undefined if `compute_uv` is `False`.
`[..., M, P]`; if `full_matrices` is `True` then shape is
`[..., M, M]`. Undefined if `compute_uv` is `False`.
v: Left singular vectors. If `full_matrices` is `False` then shape is
`[..., N, N]`. If `full_matrices` is `True` then shape is `[..., N, P]`.
`[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`.
Undefined if `compute_uv` is false.
compute_uv: If true, left and right singular vectors will be
computed and returned in `u` and `v`, respectively.
......
......@@ -171,6 +171,50 @@ TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;?;[1]");
}
TEST(LinalgOpsTest, Qr_ShapeFn) {
ShapeInferenceTestOp op("Qr");
auto set_attrs = [&op](bool full_matrices) {
TF_ASSERT_OK(NodeDefBuilder("test", "Qr")
.Input({"input", 0, DT_FLOAT})
.Attr("full_matrices", full_matrices)
.Finalize(&op.node_def));
};
// Defining `P` = min(`M`, `N`), if full_matrices = False, then Q should be
// `M` x `P` and `R` should be `P` x `N`. Otherwise, Q should be
// `M` x `M` and `R` should be `M` x `N`.
//
// For rank-3 tensors, `M` = d0_1 and `N` = d0_2.
//
set_attrs(false);
INFER_OK(op, "?", "?;?");
INFER_OK(op, "[?,?,?]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
INFER_OK(op, "[4,?,?]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
INFER_OK(op, "[4,2,?]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
INFER_OK(op, "[4,?,2]", "[d0_0,d0_1,?];[d0_0,?,d0_2]");
INFER_OK(op, "[?,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[?,3,2]", "[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]");
INFER_OK(op, "[4,3,2]", "[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]");
INFER_OK(op, "[?,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
set_attrs(true);
INFER_OK(op, "?", "?;?");
INFER_OK(op, "[?,?,?]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,?,?]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,2,?]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,?,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[?,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,2,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[?,3,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,3,2]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[?,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_OK(op, "[4,2,3]", "[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
}
TEST(LinalgOpsTest, Svd_ShapeFn) {
ShapeInferenceTestOp op("Svd");
auto set_attrs = [&op](bool compute_uv, bool full_matrices) {
......@@ -180,6 +224,13 @@ TEST(LinalgOpsTest, Svd_ShapeFn) {
.Attr("full_matrices", full_matrices)
.Finalize(&op.node_def));
};
// Defining `P` = min(`M`, `N`), if full_matrices = False, then U should be
// `M` x `P` and `V` should be `N` x `P`. Otherwise, U should be
// `M` x `M` and `V` should be `N` x `N`.
//
// For rank-3 tensors, `M` = d0_1 and `N` = d0_2.
//
set_attrs(false, false);
INFER_OK(op, "?", "?;[0];[0]");
INFER_OK(op, "[?,?,?]", "[d0_0,?];[0];[0]");
......
......@@ -1030,8 +1030,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":array_ops",
":control_flow_ops_gen",
":data_flow_ops_gen",
":framework",
":framework_for_generated_wrappers",
":math_ops_gen",
":sparse_ops_gen",
":state_ops",
......
......@@ -1355,6 +1355,15 @@ cuda_py_test(
shard_count = 20,
)
cuda_py_test(
name = "qr_op_test",
size = "medium",
srcs = ["qr_op_test.py"],
additional_deps = ["//tensorflow:tensorflow_py"],
shard_count = 20,
tags = ["nomsan"], # fails in msan from numpy calls
)
cuda_py_test(
name = "svd_op_test",
size = "medium",
......
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class QrOpTest(tf.test.TestCase):
def testWrongDimensions(self):
# The input to qr should be a tensor of at least rank 2.
scalar = tf.constant(1.)
with self.assertRaisesRegexp(ValueError,
"Shape must be at least rank 2 but is rank 0"):
tf.qr(scalar)
vector = tf.constant([1., 2.])
with self.assertRaisesRegexp(ValueError,
"Shape must be at least rank 2 but is rank 1"):
tf.qr(vector)
def _GetQrOpTest(dtype_, shape_, use_static_shape_):
is_complex = dtype_ in (np.complex64, np.complex128)
is_single = dtype_ in (np.float32, np.complex64)
def CompareOrthogonal(self, x, y, rank):
if is_single:
atol = 5e-4
else:
atol = 5e-14
# We only compare the first 'rank' orthogonal vectors since the
# remainder form an arbitrary orthonormal basis for the
# (row- or column-) null space, whose exact value depends on
# implementation details. Notice that since we check that the
# matrices of singular vectors are unitary elsewhere, we do
# implicitly test that the trailing vectors of x and y span the
# same space.
x = x[..., 0:rank]
y = y[..., 0:rank]
# Q is only unique up to sign (complex phase factor for complex matrices),
# so we normalize the sign first.
sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
x *= phases
self.assertAllClose(x, y, atol=atol)
def CheckApproximation(self, a, q, r):
if is_single:
tol = 1e-5
else:
tol = 1e-14
# Tests that a ~= q*r.
a_recon = tf.matmul(q, r)
self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol)
def CheckUnitary(self, x):
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
xx = tf.matmul(tf.conj(x), x, transpose_a=True)
identity = tf.matrix_band_part(tf.ones_like(xx), 0, 0)
if is_single:
tol = 1e-5
else:
tol = 1e-14
self.assertAllClose(identity.eval(), xx.eval(), atol=tol)
def Test(self):
np.random.seed(1)
x_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
if is_complex:
x_np += 1j * np.random.uniform(
low=-1.0, high=1.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
for full_matrices in False, True:
with self.test_session() as sess:
if use_static_shape_:
x_tf = tf.constant(x_np)
else:
x_tf = tf.placeholder(dtype_)
q_tf, r_tf = tf.qr(x_tf, full_matrices=full_matrices)
if use_static_shape_:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
else:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
q_dims = q_tf_val.shape
np_q = np.ndarray(q_dims, dtype_)
np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
new_first_dim = np_q_reshape.shape[0]
x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
for i in range(new_first_dim):
if full_matrices:
np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="complete")
else:
np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="reduced")
np_q = np.reshape(np_q_reshape, q_dims)
CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
CheckApproximation(self, x_np, q_tf_val, r_tf_val)
CheckUnitary(self, q_tf_val)
return Test
if __name__ == "__main__":
for dtype in np.float32, np.float64, np.complex64, np.complex128:
for rows in 1, 2, 5, 10, 32, 100:
for cols in 1, 2, 5, 10, 32, 100:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
shape = batch_dims + (rows, cols)
for use_static_shape in True, False:
name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)),
use_static_shape)
setattr(QrOpTest, "testQr_" + name,
_GetQrOpTest(dtype, shape, use_static_shape))
tf.test.main()
......@@ -24,7 +24,7 @@ import tensorflow as tf
class SvdOpTest(tf.test.TestCase):
def testWrongDimensions(self):
# The input to batch_svd should be a tensor of at least rank 2.
# The input to svd should be a tensor of at least rank 2.
scalar = tf.constant(1.)
with self.assertRaisesRegexp(ValueError,
"Shape must be at least rank 2 but is rank 0"):
......@@ -35,7 +35,7 @@ class SvdOpTest(tf.test.TestCase):
tf.svd(vector)
def _GetSvdOpTest(dtype_, shape_):
def _GetSvdOpTest(dtype_, shape_, use_static_shape_):
is_complex = dtype_ in (np.complex64, np.complex128)
is_single = dtype_ in (np.float32, np.complex64)
......@@ -101,40 +101,61 @@ def _GetSvdOpTest(dtype_, shape_):
def Test(self):
np.random.seed(1)
x = np.random.uniform(
x_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
if is_complex:
x += 1j * np.random.uniform(
x_np += 1j * np.random.uniform(
low=-1.0, high=1.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
for compute_uv in False, True:
for full_matrices in False, True:
with self.test_session():
with self.test_session() as sess:
if use_static_shape_:
x_tf = tf.constant(x_np)
else:
x_tf = tf.placeholder(dtype_)
if compute_uv:
tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
s_tf, u_tf, v_tf = tf.svd(x_tf,
compute_uv=compute_uv,
full_matrices=full_matrices)
if use_static_shape_:
s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf])
else:
s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf],
feed_dict={x_tf: x_np})
else:
tf_s = tf.svd(tf.constant(x),
s_tf = tf.svd(x_tf,
compute_uv=compute_uv,
full_matrices=full_matrices)
if use_static_shape_:
s_tf_val = sess.run(s_tf)
else:
s_tf_val = sess.run(s_tf, feed_dict={x_tf: x_np})
if compute_uv:
np_u, np_s, np_v = np.linalg.svd(x,
u_np, s_np, v_np = np.linalg.svd(x_np,
compute_uv=compute_uv,
full_matrices=full_matrices)
else:
np_s = np.linalg.svd(x,
s_np = np.linalg.svd(x_np,
compute_uv=compute_uv,
full_matrices=full_matrices)
CompareSingularValues(self, np_s, tf_s.eval())
# We explicitly avoid the situation where numpy eliminates a first
# dimension that is equal to one
s_np = np.reshape(s_np, s_tf_val.shape)
CompareSingularValues(self, s_np, s_tf_val)
if compute_uv:
CompareSingularVectors(self, np_u, tf_u.eval(), min(shape_[-2:]))
CompareSingularVectors(self, np.conj(np.swapaxes(np_v, -2, -1)),
tf_v.eval(), min(shape_[-2:]))
CheckApproximation(self, x, tf_u, tf_s, tf_v, full_matrices)
CheckUnitary(self, tf_u)
CheckUnitary(self, tf_v)
CompareSingularVectors(self, u_np, u_tf_val, min(shape_[-2:]))
CompareSingularVectors(self,
np.conj(np.swapaxes(v_np, -2, -1)), v_tf_val,
min(shape_[-2:]))
CheckApproximation(self, x_np, u_tf_val, s_tf_val, v_tf_val,
full_matrices)
CheckUnitary(self, u_tf_val)
CheckUnitary(self, v_tf_val)
return Test
......@@ -145,6 +166,9 @@ if __name__ == "__main__":
for cols in 1, 2, 5, 10, 32, 100:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
shape = batch_dims + (rows, cols)
name = "%s_%s" % (dtype.__name__, "_".join(map(str, shape)))
setattr(SvdOpTest, "testSvd_" + name, _GetSvdOpTest(dtype, shape))
for use_static_shape in True, False:
name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)),
use_static_shape)
setattr(SvdOpTest, "testSvd_" + name,
_GetSvdOpTest(dtype, shape, use_static_shape))
tf.test.main()
......@@ -104,6 +104,7 @@ functions on matrices to your graph.
@@matrix_solve
@@matrix_triangular_solve
@@matrix_solve_ls
@@qr
@@self_adjoint_eig
@@self_adjoint_eigvals
@@svd
......@@ -949,6 +950,7 @@ def div(x, y, name=None):
def div_deprecated(x, y, name=None):
return gen_math_ops.div(x, y, name)
mod = gen_math_ops.floor_mod
......@@ -1001,6 +1003,7 @@ def floordiv_deprecated(x, y, name=None):
# return gen_math_ops.floor_div(x, y, name=name)
return gen_math_ops.div(x, y, name=name)
realdiv = gen_math_ops.real_div
truncatediv = gen_math_ops.truncate_div
# TODO(aselle): Rename this to floordiv when we can.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册