Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
715f951e
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
715f951e
编写于
11月 22, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
11月 22, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adding a TF QR op.
Change: 139959769
上级
dc4b868b
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
538 addition
and
47 deletion
+538
-47
tensorflow/contrib/cmake/tf_tests.cmake
tensorflow/contrib/cmake/tf_tests.cmake
+30
-26
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/BUILD
+1
-0
tensorflow/core/kernels/qr_op_complex128.cc
tensorflow/core/kernels/qr_op_complex128.cc
+22
-0
tensorflow/core/kernels/qr_op_complex64.cc
tensorflow/core/kernels/qr_op_complex64.cc
+22
-0
tensorflow/core/kernels/qr_op_double.cc
tensorflow/core/kernels/qr_op_double.cc
+22
-0
tensorflow/core/kernels/qr_op_float.cc
tensorflow/core/kernels/qr_op_float.cc
+22
-0
tensorflow/core/kernels/qr_op_impl.h
tensorflow/core/kernels/qr_op_impl.h
+110
-0
tensorflow/core/ops/linalg_ops.cc
tensorflow/core/ops/linalg_ops.cc
+65
-3
tensorflow/core/ops/linalg_ops_test.cc
tensorflow/core/ops/linalg_ops_test.cc
+51
-0
tensorflow/python/BUILD
tensorflow/python/BUILD
+2
-0
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/BUILD
+9
-0
tensorflow/python/kernel_tests/qr_op_test.py
tensorflow/python/kernel_tests/qr_op_test.py
+137
-0
tensorflow/python/kernel_tests/svd_op_test.py
tensorflow/python/kernel_tests/svd_op_test.py
+42
-18
tensorflow/python/ops/math_ops.py
tensorflow/python/ops/math_ops.py
+3
-0
未找到文件。
tensorflow/contrib/cmake/tf_tests.cmake
浏览文件 @
715f951e
...
...
@@ -4,16 +4,16 @@ enable_testing()
# get a temp path for test data
#
function
(
GetTestRunPath VAR_NAME OBJ_NAME
)
if
(
WIN32
)
if
(
DEFINED ENV{TMP}
)
if
(
WIN32
)
if
(
DEFINED ENV{TMP}
)
set
(
TMPDIR
"$ENV{TMP}"
)
elseif
(
DEFINED ENV{TEMP}
)
set
(
TMPDIR
"$ENV{TEMP}"
)
endif
()
string
(
REPLACE
"
\\
"
"/"
TMPDIR
${
TMPDIR
}
)
else
()
set
(
TMPDIR
"$ENV{TMPDIR}"
)
endif
()
else
()
set
(
TMPDIR
"$ENV{TMPDIR}"
)
endif
()
if
(
NOT EXISTS
"
${
TMPDIR
}
"
)
message
(
FATAL_ERROR
"Unable to determine a path to the temporary directory"
)
endif
()
...
...
@@ -45,7 +45,7 @@ endfunction(AddTests)
#
function
(
AddTest
)
cmake_parse_arguments
(
_AT
""
"TARGET"
"SOURCES;OBJECTS;LIBS;DATA;DEPENDS"
${
ARGN
}
)
list
(
REMOVE_DUPLICATES _AT_SOURCES
)
list
(
REMOVE_DUPLICATES _AT_OBJECTS
)
list
(
REMOVE_DUPLICATES _AT_LIBS
)
...
...
@@ -55,7 +55,7 @@ function(AddTest)
if
(
_AT_DEPENDS
)
list
(
REMOVE_DUPLICATES _AT_DEPENDS
)
endif
(
_AT_DEPENDS
)
add_executable
(
${
_AT_TARGET
}
${
_AT_SOURCES
}
${
_AT_OBJECTS
}
)
target_link_libraries
(
${
_AT_TARGET
}
${
_AT_LIBS
}
)
...
...
@@ -96,7 +96,7 @@ function(AddPythonTests)
if
(
_AT_DEPENDS
)
list
(
REMOVE_DUPLICATES _AT_DEPENDS
)
endif
(
_AT_DEPENDS
)
foreach
(
sourcefile
${
_AT_SOURCES
}
)
add_test
(
NAME
${
sourcefile
}
COMMAND
${
PYTHON_EXECUTABLE
}
${
sourcefile
}
)
if
(
_AT_DEPENDS
)
...
...
@@ -108,11 +108,11 @@ endfunction(AddPythonTests)
if
(
tensorflow_BUILD_PYTHON_TESTS
)
#
# python tests. This assumes that the tensorflow wheel is
# installed on the test system.
# installed on the test system.
# TODO: we currently don't handle tests that need to have
# some environment setup: see AddTest how to add this
#
# include all test
file
(
GLOB_RECURSE tf_test_src_py
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/*.py"
...
...
@@ -124,14 +124,14 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/__init__.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/benchmark_test.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/resource_variable_ops_test.py"
)
)
if
(
WIN32
)
set
(
tf_test_src_py_exclude
${
tf_test_src_py_exclude
}
# generally excluded
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/__init__.py"
# TODO: failing tests.
# TODO: failing tests.
# Nothing critical in here but should get this list down to []
# The failing list is grouped by failure source
# stl on windows handles overflows different
...
...
@@ -148,9 +148,13 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/py_func_test.py"
# issues related to windows fs
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/io_ops_test.py"
# missing kernel
# missing kernel
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/conv_ops_test.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/depthwise_conv_op_test.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/pool_test.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/qr_op_test.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/svd_op_test.py"
# cuda launch failed
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/diag_op_test.py"
"
${
tensorflow_source_dir
}
/tensorflow/python/kernel_tests/trace_op_test.py"
...
...
@@ -158,10 +162,10 @@ if (tensorflow_BUILD_PYTHON_TESTS)
)
endif
()
list
(
REMOVE_ITEM tf_test_src_py
${
tf_test_src_py_exclude
}
)
AddPythonTests
(
SOURCES
${
tf_test_src_py
}
)
)
endif
(
tensorflow_BUILD_PYTHON_TESTS
)
if
(
tensorflow_BUILD_CC_TESTS
)
...
...
@@ -169,9 +173,9 @@ if (tensorflow_BUILD_CC_TESTS)
# cc unit tests. Be aware that by default we include 250+ tests which
# will take time and space to build.
# If you wan to cut this down, for example to a specific test, modify
# tf_test_src_simple to your needs
# tf_test_src_simple to your needs
#
include_directories
(
${
googletest_INCLUDE_DIRS
}
)
# cc tests wrapper
...
...
@@ -228,7 +232,7 @@ if (tensorflow_BUILD_CC_TESTS)
# generally excluded
"
${
tensorflow_source_dir
}
/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc"
"
${
tensorflow_source_dir
}
/tensorflow/cc/framework/cc_ops_test.cc"
# test_op.h missing
# TODO: test failing
"
${
tensorflow_source_dir
}
/tensorflow/core/common_runtime/simple_placer_test.cc"
"
${
tensorflow_source_dir
}
/tensorflow/core/distributed_runtime/executor_test.cc"
...
...
@@ -254,7 +258,7 @@ if (tensorflow_BUILD_CC_TESTS)
"
${
tensorflow_source_dir
}
/tensorflow/contrib/rnn/ops/gru_ops_test.cc"
# status 5
"
${
tensorflow_source_dir
}
/tensorflow/contrib/rnn/ops/lstm_ops_test.cc"
# status 5
# TODO: not compiling
# TODO: not compiling
"
${
tensorflow_source_dir
}
/tensorflow/cc/framework/gradient_checker_test.cc"
"
${
tensorflow_source_dir
}
/tensorflow/cc/gradients/math_grad_test.cc"
"
${
tensorflow_source_dir
}
/tensorflow/cc/gradients/array_grad_test.cc"
...
...
@@ -344,13 +348,13 @@ if (tensorflow_BUILD_CC_TESTS)
endif
()
list
(
REMOVE_ITEM tf_test_src_simple
${
tf_test_src_simple_exclude
}
)
set
(
tf_test_lib tf_test_lib
)
add_library
(
${
tf_test_lib
}
STATIC
${
tf_src_testlib
}
)
# this is giving to much objects and libraries to the linker but
# this is giving to much objects and libraries to the linker but
# it makes this script much easier. So for now we do it this way.
set
(
tf_obj_test
set
(
tf_obj_test
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
...
...
@@ -362,10 +366,10 @@ if (tensorflow_BUILD_CC_TESTS)
$<$<BOOL:
${
tensorflow_ENABLE_GPU
}
>:$<TARGET_OBJECTS:tf_stream_executor>>
)
set
(
tf_test_libs
set
(
tf_test_libs
tf_protos_cc
tf_test_lib
${
tf_core_gpu_kernels_lib
}
${
tf_core_gpu_kernels_lib
}
${
googletest_STATIC_LIBRARIES
}
${
tensorflow_EXTERNAL_LIBRARIES
}
)
...
...
@@ -373,7 +377,7 @@ if (tensorflow_BUILD_CC_TESTS)
AddTests
(
SOURCES
${
tf_test_src_simple
}
OBJECTS
${
tf_obj_test
}
LIBS
${
tf_test_libs
}
LIBS
${
tf_test_libs
}
DEPENDS googletest
)
endif
(
tensorflow_BUILD_CC_TESTS
)
tensorflow/core/kernels/BUILD
浏览文件 @
715f951e
...
...
@@ -1324,6 +1324,7 @@ tf_kernel_libraries(
"matrix_solve_ls_op"
,
"matrix_solve_op"
,
"matrix_triangular_solve_op"
,
"qr_op"
,
"svd_op"
,
],
deps
=
[
...
...
tensorflow/core/kernels/qr_op_complex128.cc
0 → 100644
浏览文件 @
715f951e
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace
tensorflow
{
REGISTER_LINALG_OP
(
"Qr"
,
(
QrOp
<
complex128
>
),
complex128
);
}
// namespace tensorflow
tensorflow/core/kernels/qr_op_complex64.cc
0 → 100644
浏览文件 @
715f951e
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace
tensorflow
{
REGISTER_LINALG_OP
(
"Qr"
,
(
QrOp
<
complex64
>
),
complex64
);
}
// namespace tensorflow
tensorflow/core/kernels/qr_op_double.cc
0 → 100644
浏览文件 @
715f951e
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace
tensorflow
{
REGISTER_LINALG_OP
(
"Qr"
,
(
QrOp
<
double
>
),
double
);
}
// namespace tensorflow
tensorflow/core/kernels/qr_op_float.cc
0 → 100644
浏览文件 @
715f951e
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
namespace
tensorflow
{
REGISTER_LINALG_OP
(
"Qr"
,
(
QrOp
<
float
>
),
float
);
}
// namespace tensorflow
tensorflow/core/kernels/qr_op_impl.h
0 → 100644
浏览文件 @
715f951e
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
#include <algorithm>
#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace
tensorflow
{
template
<
class
Scalar
>
class
QrOp
:
public
LinearAlgebraOp
<
Scalar
>
{
public:
typedef
LinearAlgebraOp
<
Scalar
>
Base
;
explicit
QrOp
(
OpKernelConstruction
*
context
)
:
Base
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"full_matrices"
,
&
full_matrices_
));
}
using
TensorShapes
=
typename
Base
::
TensorShapes
;
void
ValidateInputMatrixShapes
(
OpKernelContext
*
context
,
const
TensorShapes
&
input_matrix_shapes
)
const
final
{
Base
::
ValidateSingleMatrix
(
context
,
input_matrix_shapes
);
}
TensorShapes
GetOutputMatrixShapes
(
const
TensorShapes
&
input_matrix_shapes
)
const
final
{
int64
m
=
input_matrix_shapes
[
0
].
dim_size
(
0
);
int64
n
=
input_matrix_shapes
[
0
].
dim_size
(
1
);
int64
min_size
=
std
::
min
(
m
,
n
);
if
(
full_matrices_
)
{
return
TensorShapes
({
TensorShape
({
m
,
m
}),
TensorShape
({
m
,
n
})});
}
else
{
return
TensorShapes
(
{
TensorShape
({
m
,
min_size
}),
TensorShape
({
min_size
,
n
})});
}
}
int64
GetCostPerUnit
(
const
TensorShapes
&
input_matrix_shapes
)
const
final
{
double
m
=
static_cast
<
double
>
(
input_matrix_shapes
[
0
].
dim_size
(
0
));
double
n
=
static_cast
<
double
>
(
input_matrix_shapes
[
0
].
dim_size
(
1
));
double
max_size
=
std
::
max
(
m
,
n
);
double
min_size
=
std
::
min
(
m
,
n
);
double
cost
=
2
*
max_size
*
min_size
*
min_size
-
2
*
min_size
*
min_size
*
min_size
/
3.
;
// TODO(jpoulson): Increase the cost if full_matrices is true in a manner
// that reflects the algorithm used for the expansion.
return
cost
>=
static_cast
<
double
>
(
kint64max
)
?
kint64max
:
static_cast
<
int64
>
(
cost
);
}
using
Matrix
=
typename
Base
::
Matrix
;
using
MatrixMaps
=
typename
Base
::
MatrixMaps
;
using
ConstMatrixMap
=
typename
Base
::
ConstMatrixMap
;
using
ConstMatrixMaps
=
typename
Base
::
ConstMatrixMaps
;
void
ComputeMatrix
(
OpKernelContext
*
context
,
const
ConstMatrixMaps
&
inputs
,
MatrixMaps
*
outputs
)
final
{
Eigen
::
HouseholderQR
<
Matrix
>
qr
(
inputs
[
0
]);
const
int
m
=
inputs
[
0
].
rows
();
const
int
n
=
inputs
[
0
].
cols
();
const
int
min_size
=
std
::
min
(
m
,
n
);
if
(
full_matrices_
)
{
outputs
->
at
(
0
)
=
qr
.
householderQ
();
outputs
->
at
(
1
)
=
qr
.
matrixQR
().
template
triangularView
<
Eigen
::
Upper
>();
}
else
{
// TODO(jpoulson): Exploit the fact that Householder transformations can
// be expanded faster than they can be applied to an arbitrary matrix
// (Cf. LAPACK's DORGQR).
Matrix
tmp
=
Matrix
::
Identity
(
m
,
min_size
);
outputs
->
at
(
0
)
=
qr
.
householderQ
()
*
tmp
;
auto
qr_top
=
qr
.
matrixQR
().
block
(
0
,
0
,
min_size
,
n
);
outputs
->
at
(
1
)
=
qr_top
.
template
triangularView
<
Eigen
::
Upper
>();
}
}
private:
bool
full_matrices_
;
TF_DISALLOW_COPY_AND_ASSIGN
(
QrOp
);
};
}
// namespace tensorflow
tensorflow/core/ops/linalg_ops.cc
浏览文件 @
715f951e
...
...
@@ -109,6 +109,36 @@ Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
return
Status
::
OK
();
}
// Input is [...,M,N].
// First and second outputs are:
// [...,M,M]; [...,M,N], if full_matrices is true,
// [...,M,P]; [...,P,N], if full_matrices is false,
// where P = min(M,N).
Status
QrShapeFn
(
InferenceContext
*
c
)
{
ShapeHandle
input
;
TF_RETURN_IF_ERROR
(
c
->
WithRankAtLeast
(
c
->
input
(
0
),
2
,
&
input
));
DimensionHandle
m
=
c
->
Dim
(
input
,
-
2
);
DimensionHandle
n
=
c
->
Dim
(
input
,
-
1
);
DimensionHandle
p
;
TF_RETURN_IF_ERROR
(
c
->
Min
(
m
,
n
,
&
p
));
ShapeHandle
batch_shape
;
TF_RETURN_IF_ERROR
(
c
->
Subshape
(
input
,
0
,
-
2
,
&
batch_shape
));
ShapeHandle
q_shape
;
ShapeHandle
r_shape
;
bool
full_matrices
;
TF_RETURN_IF_ERROR
(
c
->
GetAttr
(
"full_matrices"
,
&
full_matrices
));
if
(
full_matrices
)
{
TF_RETURN_IF_ERROR
(
c
->
Concatenate
(
batch_shape
,
c
->
Matrix
(
m
,
m
),
&
q_shape
));
TF_RETURN_IF_ERROR
(
c
->
Concatenate
(
batch_shape
,
c
->
Matrix
(
m
,
n
),
&
r_shape
));
}
else
{
TF_RETURN_IF_ERROR
(
c
->
Concatenate
(
batch_shape
,
c
->
Matrix
(
m
,
p
),
&
q_shape
));
TF_RETURN_IF_ERROR
(
c
->
Concatenate
(
batch_shape
,
c
->
Matrix
(
p
,
n
),
&
r_shape
));
}
c
->
set_output
(
0
,
q_shape
);
c
->
set_output
(
1
,
r_shape
);
return
Status
::
OK
();
}
// Input is [...,M,N]. First output is [...,min(M,N)].
// Second and third outputs are:
// [0]; [0], if compute_uv is false.
...
...
@@ -435,6 +465,38 @@ Equivalent to np.linalg.lstsq
@end_compatibility
)doc"
);
REGISTER_OP
(
"Qr"
)
.
Input
(
"input: T"
)
.
Output
(
"q: T"
)
.
Output
(
"r: T"
)
.
Attr
(
"full_matrices: bool = False"
)
.
Attr
(
"T: {double, float, complex64, complex128}"
)
.
SetShapeFn
(
QrShapeFn
)
.
Doc
(
R"doc(
Computes the QR decompositions of one or more matrices.
Computes the QR decomposition of each inner matrix in `tensor` such that
`tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
```prettyprint
# a is a tensor.
# q is a tensor of orthonormal matrices.
# r is a tensor of upper triangular matrices.
q, r = qr(a)
q_full, r_full = qr(a, full_matrices=True)
```
input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
q: Orthonormal basis for range of `a`. If `full_matrices` is `False` then
shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
`[..., M, M]`.
r: Triangular factor. If `full_matrices` is `False` then shape is
`[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
full_matrices: If true, compute full-sized `q` and `r`. If false
(the default), compute only the leading `P` columns of `q`.
)doc"
);
REGISTER_OP
(
"Svd"
)
.
Input
(
"input: T"
)
.
Output
(
"s: T"
)
...
...
@@ -463,10 +525,10 @@ input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
s: Singular values. Shape is `[..., P]`.
u: Left singular vectors. If `full_matrices` is `False` then shape is
`[..., M,
M
]`; if `full_matrices` is `True` then shape is
`[..., M,
P
]`. Undefined if `compute_uv` is `False`.
`[..., M,
P
]`; if `full_matrices` is `True` then shape is
`[..., M,
M
]`. Undefined if `compute_uv` is `False`.
v: Left singular vectors. If `full_matrices` is `False` then shape is
`[..., N,
N]`. If `full_matrices` is `True` then shape is `[..., N, P
]`.
`[..., N,
P]`. If `full_matrices` is `True` then shape is `[..., N, N
]`.
Undefined if `compute_uv` is false.
compute_uv: If true, left and right singular vectors will be
computed and returned in `u` and `v`, respectively.
...
...
tensorflow/core/ops/linalg_ops_test.cc
浏览文件 @
715f951e
...
...
@@ -171,6 +171,50 @@ TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {
INFER_ERROR
(
"Shape must be rank 0 but is rank 1"
,
op
,
"?;?;[1]"
);
}
TEST
(
LinalgOpsTest
,
Qr_ShapeFn
)
{
ShapeInferenceTestOp
op
(
"Qr"
);
auto
set_attrs
=
[
&
op
](
bool
full_matrices
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test"
,
"Qr"
)
.
Input
({
"input"
,
0
,
DT_FLOAT
})
.
Attr
(
"full_matrices"
,
full_matrices
)
.
Finalize
(
&
op
.
node_def
));
};
// Defining `P` = min(`M`, `N`), if full_matrices = False, then Q should be
// `M` x `P` and `R` should be `P` x `N`. Otherwise, Q should be
// `M` x `M` and `R` should be `M` x `N`.
//
// For rank-3 tensors, `M` = d0_1 and `N` = d0_2.
//
set_attrs
(
false
);
INFER_OK
(
op
,
"?"
,
"?;?"
);
INFER_OK
(
op
,
"[?,?,?]"
,
"[d0_0,d0_1,?];[d0_0,?,d0_2]"
);
INFER_OK
(
op
,
"[4,?,?]"
,
"[d0_0,d0_1,?];[d0_0,?,d0_2]"
);
INFER_OK
(
op
,
"[4,2,?]"
,
"[d0_0,d0_1,?];[d0_0,?,d0_2]"
);
INFER_OK
(
op
,
"[4,?,2]"
,
"[d0_0,d0_1,?];[d0_0,?,d0_2]"
);
INFER_OK
(
op
,
"[?,2,2]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,2,2]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[?,3,2]"
,
"[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]"
);
INFER_OK
(
op
,
"[4,3,2]"
,
"[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]"
);
INFER_OK
(
op
,
"[?,2,3]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,2,3]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_ERROR
(
"Shape must be at least rank 2 but is rank 1"
,
op
,
"[1]"
);
set_attrs
(
true
);
INFER_OK
(
op
,
"?"
,
"?;?"
);
INFER_OK
(
op
,
"[?,?,?]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,?,?]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,2,?]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,?,2]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[?,2,2]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,2,2]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[?,3,2]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,3,2]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[?,2,3]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_OK
(
op
,
"[4,2,3]"
,
"[d0_0,d0_1,d0_1];[d0_0,d0_1,d0_2]"
);
INFER_ERROR
(
"Shape must be at least rank 2 but is rank 1"
,
op
,
"[1]"
);
}
TEST
(
LinalgOpsTest
,
Svd_ShapeFn
)
{
ShapeInferenceTestOp
op
(
"Svd"
);
auto
set_attrs
=
[
&
op
](
bool
compute_uv
,
bool
full_matrices
)
{
...
...
@@ -180,6 +224,13 @@ TEST(LinalgOpsTest, Svd_ShapeFn) {
.
Attr
(
"full_matrices"
,
full_matrices
)
.
Finalize
(
&
op
.
node_def
));
};
// Defining `P` = min(`M`, `N`), if full_matrices = False, then U should be
// `M` x `P` and `V` should be `N` x `P`. Otherwise, U should be
// `M` x `M` and `V` should be `N` x `N`.
//
// For rank-3 tensors, `M` = d0_1 and `N` = d0_2.
//
set_attrs
(
false
,
false
);
INFER_OK
(
op
,
"?"
,
"?;[0];[0]"
);
INFER_OK
(
op
,
"[?,?,?]"
,
"[d0_0,?];[0];[0]"
);
...
...
tensorflow/python/BUILD
浏览文件 @
715f951e
...
...
@@ -1030,8 +1030,10 @@ py_library(
srcs_version
=
"PY2AND3"
,
deps
=
[
":array_ops"
,
":control_flow_ops_gen"
,
":data_flow_ops_gen"
,
":framework"
,
":framework_for_generated_wrappers"
,
":math_ops_gen"
,
":sparse_ops_gen"
,
":state_ops"
,
...
...
tensorflow/python/kernel_tests/BUILD
浏览文件 @
715f951e
...
...
@@ -1355,6 +1355,15 @@ cuda_py_test(
shard_count
=
20
,
)
cuda_py_test
(
name
=
"qr_op_test"
,
size
=
"medium"
,
srcs
=
[
"qr_op_test.py"
],
additional_deps
=
[
"//tensorflow:tensorflow_py"
],
shard_count
=
20
,
tags
=
[
"nomsan"
],
# fails in msan from numpy calls
)
cuda_py_test
(
name
=
"svd_op_test"
,
size
=
"medium"
,
...
...
tensorflow/python/kernel_tests/qr_op_test.py
0 → 100644
浏览文件 @
715f951e
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
class
QrOpTest
(
tf
.
test
.
TestCase
):
def
testWrongDimensions
(
self
):
# The input to qr should be a tensor of at least rank 2.
scalar
=
tf
.
constant
(
1.
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Shape must be at least rank 2 but is rank 0"
):
tf
.
qr
(
scalar
)
vector
=
tf
.
constant
([
1.
,
2.
])
with
self
.
assertRaisesRegexp
(
ValueError
,
"Shape must be at least rank 2 but is rank 1"
):
tf
.
qr
(
vector
)
def
_GetQrOpTest
(
dtype_
,
shape_
,
use_static_shape_
):
is_complex
=
dtype_
in
(
np
.
complex64
,
np
.
complex128
)
is_single
=
dtype_
in
(
np
.
float32
,
np
.
complex64
)
def
CompareOrthogonal
(
self
,
x
,
y
,
rank
):
if
is_single
:
atol
=
5e-4
else
:
atol
=
5e-14
# We only compare the first 'rank' orthogonal vectors since the
# remainder form an arbitrary orthonormal basis for the
# (row- or column-) null space, whose exact value depends on
# implementation details. Notice that since we check that the
# matrices of singular vectors are unitary elsewhere, we do
# implicitly test that the trailing vectors of x and y span the
# same space.
x
=
x
[...,
0
:
rank
]
y
=
y
[...,
0
:
rank
]
# Q is only unique up to sign (complex phase factor for complex matrices),
# so we normalize the sign first.
sum_of_ratios
=
np
.
sum
(
np
.
divide
(
y
,
x
),
-
2
,
keepdims
=
True
)
phases
=
np
.
divide
(
sum_of_ratios
,
np
.
abs
(
sum_of_ratios
))
x
*=
phases
self
.
assertAllClose
(
x
,
y
,
atol
=
atol
)
def
CheckApproximation
(
self
,
a
,
q
,
r
):
if
is_single
:
tol
=
1e-5
else
:
tol
=
1e-14
# Tests that a ~= q*r.
a_recon
=
tf
.
matmul
(
q
,
r
)
self
.
assertAllClose
(
a_recon
.
eval
(),
a
,
rtol
=
tol
,
atol
=
tol
)
def
CheckUnitary
(
self
,
x
):
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
xx
=
tf
.
matmul
(
tf
.
conj
(
x
),
x
,
transpose_a
=
True
)
identity
=
tf
.
matrix_band_part
(
tf
.
ones_like
(
xx
),
0
,
0
)
if
is_single
:
tol
=
1e-5
else
:
tol
=
1e-14
self
.
assertAllClose
(
identity
.
eval
(),
xx
.
eval
(),
atol
=
tol
)
def
Test
(
self
):
np
.
random
.
seed
(
1
)
x_np
=
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
np
.
prod
(
shape_
)).
reshape
(
shape_
).
astype
(
dtype_
)
if
is_complex
:
x_np
+=
1j
*
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
np
.
prod
(
shape_
)).
reshape
(
shape_
).
astype
(
dtype_
)
for
full_matrices
in
False
,
True
:
with
self
.
test_session
()
as
sess
:
if
use_static_shape_
:
x_tf
=
tf
.
constant
(
x_np
)
else
:
x_tf
=
tf
.
placeholder
(
dtype_
)
q_tf
,
r_tf
=
tf
.
qr
(
x_tf
,
full_matrices
=
full_matrices
)
if
use_static_shape_
:
q_tf_val
,
r_tf_val
=
sess
.
run
([
q_tf
,
r_tf
])
else
:
q_tf_val
,
r_tf_val
=
sess
.
run
([
q_tf
,
r_tf
],
feed_dict
=
{
x_tf
:
x_np
})
q_dims
=
q_tf_val
.
shape
np_q
=
np
.
ndarray
(
q_dims
,
dtype_
)
np_q_reshape
=
np
.
reshape
(
np_q
,
(
-
1
,
q_dims
[
-
2
],
q_dims
[
-
1
]))
new_first_dim
=
np_q_reshape
.
shape
[
0
]
x_reshape
=
np
.
reshape
(
x_np
,
(
-
1
,
x_np
.
shape
[
-
2
],
x_np
.
shape
[
-
1
]))
for
i
in
range
(
new_first_dim
):
if
full_matrices
:
np_q_reshape
[
i
,:,:],
_
=
\
np
.
linalg
.
qr
(
x_reshape
[
i
,:,:],
mode
=
"complete"
)
else
:
np_q_reshape
[
i
,:,:],
_
=
\
np
.
linalg
.
qr
(
x_reshape
[
i
,:,:],
mode
=
"reduced"
)
np_q
=
np
.
reshape
(
np_q_reshape
,
q_dims
)
CompareOrthogonal
(
self
,
np_q
,
q_tf_val
,
min
(
shape_
[
-
2
:]))
CheckApproximation
(
self
,
x_np
,
q_tf_val
,
r_tf_val
)
CheckUnitary
(
self
,
q_tf_val
)
return
Test
if
__name__
==
"__main__"
:
for
dtype
in
np
.
float32
,
np
.
float64
,
np
.
complex64
,
np
.
complex128
:
for
rows
in
1
,
2
,
5
,
10
,
32
,
100
:
for
cols
in
1
,
2
,
5
,
10
,
32
,
100
:
for
batch_dims
in
[(),
(
3
,)]
+
[(
3
,
2
)]
*
(
max
(
rows
,
cols
)
<
10
):
shape
=
batch_dims
+
(
rows
,
cols
)
for
use_static_shape
in
True
,
False
:
name
=
"%s_%s_%s"
%
(
dtype
.
__name__
,
"_"
.
join
(
map
(
str
,
shape
)),
use_static_shape
)
setattr
(
QrOpTest
,
"testQr_"
+
name
,
_GetQrOpTest
(
dtype
,
shape
,
use_static_shape
))
tf
.
test
.
main
()
tensorflow/python/kernel_tests/svd_op_test.py
浏览文件 @
715f951e
...
...
@@ -24,7 +24,7 @@ import tensorflow as tf
class
SvdOpTest
(
tf
.
test
.
TestCase
):
def
testWrongDimensions
(
self
):
# The input to
batch_
svd should be a tensor of at least rank 2.
# The input to svd should be a tensor of at least rank 2.
scalar
=
tf
.
constant
(
1.
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Shape must be at least rank 2 but is rank 0"
):
...
...
@@ -35,7 +35,7 @@ class SvdOpTest(tf.test.TestCase):
tf
.
svd
(
vector
)
def
_GetSvdOpTest
(
dtype_
,
shape_
):
def
_GetSvdOpTest
(
dtype_
,
shape_
,
use_static_shape_
):
is_complex
=
dtype_
in
(
np
.
complex64
,
np
.
complex128
)
is_single
=
dtype_
in
(
np
.
float32
,
np
.
complex64
)
...
...
@@ -101,40 +101,61 @@ def _GetSvdOpTest(dtype_, shape_):
def
Test
(
self
):
np
.
random
.
seed
(
1
)
x
=
np
.
random
.
uniform
(
x
_np
=
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
np
.
prod
(
shape_
)).
reshape
(
shape_
).
astype
(
dtype_
)
if
is_complex
:
x
+=
1j
*
np
.
random
.
uniform
(
x
_np
+=
1j
*
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
np
.
prod
(
shape_
)).
reshape
(
shape_
).
astype
(
dtype_
)
for
compute_uv
in
False
,
True
:
for
full_matrices
in
False
,
True
:
with
self
.
test_session
():
with
self
.
test_session
()
as
sess
:
if
use_static_shape_
:
x_tf
=
tf
.
constant
(
x_np
)
else
:
x_tf
=
tf
.
placeholder
(
dtype_
)
if
compute_uv
:
tf_s
,
tf_u
,
tf_v
=
tf
.
svd
(
tf
.
constant
(
x
)
,
s_tf
,
u_tf
,
v_tf
=
tf
.
svd
(
x_tf
,
compute_uv
=
compute_uv
,
full_matrices
=
full_matrices
)
if
use_static_shape_
:
s_tf_val
,
u_tf_val
,
v_tf_val
=
sess
.
run
([
s_tf
,
u_tf
,
v_tf
])
else
:
s_tf_val
,
u_tf_val
,
v_tf_val
=
sess
.
run
([
s_tf
,
u_tf
,
v_tf
],
feed_dict
=
{
x_tf
:
x_np
})
else
:
tf_s
=
tf
.
svd
(
tf
.
constant
(
x
)
,
s_tf
=
tf
.
svd
(
x_tf
,
compute_uv
=
compute_uv
,
full_matrices
=
full_matrices
)
if
use_static_shape_
:
s_tf_val
=
sess
.
run
(
s_tf
)
else
:
s_tf_val
=
sess
.
run
(
s_tf
,
feed_dict
=
{
x_tf
:
x_np
})
if
compute_uv
:
np_u
,
np_s
,
np_v
=
np
.
linalg
.
svd
(
x
,
u_np
,
s_np
,
v_np
=
np
.
linalg
.
svd
(
x_np
,
compute_uv
=
compute_uv
,
full_matrices
=
full_matrices
)
else
:
np_s
=
np
.
linalg
.
svd
(
x
,
s_np
=
np
.
linalg
.
svd
(
x_np
,
compute_uv
=
compute_uv
,
full_matrices
=
full_matrices
)
CompareSingularValues
(
self
,
np_s
,
tf_s
.
eval
())
# We explicitly avoid the situation where numpy eliminates a first
# dimension that is equal to one
s_np
=
np
.
reshape
(
s_np
,
s_tf_val
.
shape
)
CompareSingularValues
(
self
,
s_np
,
s_tf_val
)
if
compute_uv
:
CompareSingularVectors
(
self
,
np_u
,
tf_u
.
eval
(),
min
(
shape_
[
-
2
:]))
CompareSingularVectors
(
self
,
np
.
conj
(
np
.
swapaxes
(
np_v
,
-
2
,
-
1
)),
tf_v
.
eval
(),
min
(
shape_
[
-
2
:]))
CheckApproximation
(
self
,
x
,
tf_u
,
tf_s
,
tf_v
,
full_matrices
)
CheckUnitary
(
self
,
tf_u
)
CheckUnitary
(
self
,
tf_v
)
CompareSingularVectors
(
self
,
u_np
,
u_tf_val
,
min
(
shape_
[
-
2
:]))
CompareSingularVectors
(
self
,
np
.
conj
(
np
.
swapaxes
(
v_np
,
-
2
,
-
1
)),
v_tf_val
,
min
(
shape_
[
-
2
:]))
CheckApproximation
(
self
,
x_np
,
u_tf_val
,
s_tf_val
,
v_tf_val
,
full_matrices
)
CheckUnitary
(
self
,
u_tf_val
)
CheckUnitary
(
self
,
v_tf_val
)
return
Test
...
...
@@ -145,6 +166,9 @@ if __name__ == "__main__":
for
cols
in
1
,
2
,
5
,
10
,
32
,
100
:
for
batch_dims
in
[(),
(
3
,)]
+
[(
3
,
2
)]
*
(
max
(
rows
,
cols
)
<
10
):
shape
=
batch_dims
+
(
rows
,
cols
)
name
=
"%s_%s"
%
(
dtype
.
__name__
,
"_"
.
join
(
map
(
str
,
shape
)))
setattr
(
SvdOpTest
,
"testSvd_"
+
name
,
_GetSvdOpTest
(
dtype
,
shape
))
for
use_static_shape
in
True
,
False
:
name
=
"%s_%s_%s"
%
(
dtype
.
__name__
,
"_"
.
join
(
map
(
str
,
shape
)),
use_static_shape
)
setattr
(
SvdOpTest
,
"testSvd_"
+
name
,
_GetSvdOpTest
(
dtype
,
shape
,
use_static_shape
))
tf
.
test
.
main
()
tensorflow/python/ops/math_ops.py
浏览文件 @
715f951e
...
...
@@ -104,6 +104,7 @@ functions on matrices to your graph.
@@matrix_solve
@@matrix_triangular_solve
@@matrix_solve_ls
@@qr
@@self_adjoint_eig
@@self_adjoint_eigvals
@@svd
...
...
@@ -949,6 +950,7 @@ def div(x, y, name=None):
def
div_deprecated
(
x
,
y
,
name
=
None
):
return
gen_math_ops
.
div
(
x
,
y
,
name
)
mod
=
gen_math_ops
.
floor_mod
...
...
@@ -1001,6 +1003,7 @@ def floordiv_deprecated(x, y, name=None):
# return gen_math_ops.floor_div(x, y, name=name)
return
gen_math_ops
.
div
(
x
,
y
,
name
=
name
)
realdiv
=
gen_math_ops
.
real_div
truncatediv
=
gen_math_ops
.
truncate_div
# TODO(aselle): Rename this to floordiv when we can.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录