提交 264bac93 作者: A. Unique TensorFlower 提交者: TensorFlower Gardener

Use new adjoint attribute for solvers to make gradients more efficient.

Consolidate linalg shape inference functions.
Change: 119423897
上级 4d9ec5ec
......@@ -75,9 +75,7 @@ def _MatrixSolveGrad(op, grad):
"""Gradients for MatrixSolve."""
a = op.inputs[0]
c = op.outputs[0]
# TODO(rmlarsen): Get rid of explicit transpose after adding
# adjoint_a attribute to solver.
grad_b = linalg_ops.matrix_solve(array_ops.transpose(a), grad)
grad_b = linalg_ops.matrix_solve(a, grad, adjoint=True)
grad_a = -math_ops.matmul(grad_b, c, transpose_b=True)
return (grad_a, grad_b)
......@@ -87,10 +85,6 @@ def _BatchMatrixSolveGrad(op, grad):
"""Gradient for BatchMatrixSolve."""
a = op.inputs[0]
c = op.outputs[0]
# TODO(rmlarsen): Replace the following two lines with
# a single call to batch_matrix_solve after adding
# in an option to solve for A^T X = Y.
ainv = linalg_ops.batch_matrix_inverse(a)
grad_b = math_ops.batch_matmul(ainv, grad, adj_x=True)
grad_b = linalg_ops.batch_matrix_solve(a, grad, adjoint=True)
grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
return (grad_a, grad_b)
......@@ -28,7 +28,8 @@ from tensorflow.python.ops.gen_linalg_ops import *
@ops.RegisterShape("Cholesky")
def _CholeskyShape(op):
@ops.RegisterShape("MatrixInverse")
def _UnchangedSquare(op):
input_shape = op.inputs[0].get_shape().with_rank(2)
# The matrix must be square.
input_shape[0].assert_is_compatible_with(input_shape[1])
......@@ -36,7 +37,8 @@ def _CholeskyShape(op):
@ops.RegisterShape("BatchCholesky")
def _BatchCholeskyShape(op):
@ops.RegisterShape("BatchMatrixInverse")
def _BatchUnchangedSquare(op):
input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
# The matrices in the batch must be square.
input_shape[-1].assert_is_compatible_with(input_shape[-2])
......@@ -65,22 +67,6 @@ def _BatchMatrixDeterminantShape(op):
return [tensor_shape.unknown_shape()]
@ops.RegisterShape("MatrixInverse")
def _MatrixInverseShape(op):
input_shape = op.inputs[0].get_shape().with_rank(2)
# The matrix must be square.
input_shape[0].assert_is_compatible_with(input_shape[1])
return [input_shape]
@ops.RegisterShape("BatchMatrixInverse")
def _BatchMatrixInverseShape(op):
input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
# The matrices in the batch must be square.
input_shape[-1].assert_is_compatible_with(input_shape[-2])
return [input_shape]
@ops.RegisterShape("SelfAdjointEig")
def _SelfAdjointEigShape(op):
input_shape = op.inputs[0].get_shape().with_rank(2)
......@@ -103,46 +89,25 @@ def _BatchSelfAdjointEigShape(op):
@ops.RegisterShape("MatrixSolve")
def _MatrixSolveShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank(2)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
# The matrix must be square.
lhs_shape[0].assert_is_compatible_with(lhs_shape[1])
# The matrix and right-hand side must have the same number of rows.
lhs_shape[0].assert_is_compatible_with(rhs_shape[0])
return [[lhs_shape[1], rhs_shape[1]]]
@ops.RegisterShape("BatchMatrixSolve")
def _BatchMatrixSolveShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
# The matrices must be square.
lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
# The matrices and right-hand sides in the batch must have the same number of
# rows.
lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
return [lhs_shape[:-1].concatenate(rhs_shape[-1])]
@ops.RegisterShape("MatrixTriangularSolve")
def _MatrixTriangularSolveShape(op):
def _SquareMatrixSolveShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank(2)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
# The matrix must be square.
lhs_shape[0].assert_is_compatible_with(lhs_shape[1])
# The matrix and righ-hand side must have the same number of rows.
# The matrix and right-hand side must have the same number of rows.
lhs_shape[0].assert_is_compatible_with(rhs_shape[0])
return [rhs_shape]
@ops.RegisterShape("BatchMatrixSolve")
@ops.RegisterShape("BatchMatrixTriangularSolve")
def _BatchMatrixTriangularSolveShape(op):
def _BatchSquareMatrixSolveShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
# The matrices must be square.
lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
# The matrices and righ-hand sides in the batch must have the same number of
# The matrices and right-hand sides in the batch must have the same number of
# rows.
lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
return [rhs_shape]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册