提交 5819055a 编写于 作者: I Ian Langmore 提交者: TensorFlower Gardener

BUGFIX: LinearOperatorAdjoint.matvec was passing an `adjoint_arg` kwarg, but

this doesn't make sense for matvec (and isn't in the base class matvec defn).
PiperOrigin-RevId: 258417922
上级 26d27c5e
......@@ -161,6 +161,13 @@ class LinearOperatorAdjointTest(
full_matrix1.matmul(
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
def test_matvec(self):
matrix = np.array([[1., 2.], [3., 4.]])
x = np.array([1., 2.])
operator = linalg.LinearOperatorFullMatrix(matrix)
self.assertAllClose(matrix.dot(x), self.evaluate(operator.matvec(x)))
self.assertAllClose(matrix.T.dot(x), self.evaluate(operator.H.matvec(x)))
def test_solve_adjoint_operator(self):
matrix1 = self.evaluate(
linear_operator_test_util.random_tril_matrix(
......@@ -223,6 +230,15 @@ class LinearOperatorAdjointTest(
full_matrix1.solve(
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
def test_solvevec(self):
matrix = np.array([[1., 2.], [3., 4.]])
inv_matrix = np.linalg.inv(matrix)
x = np.array([1., 2.])
operator = linalg.LinearOperatorFullMatrix(matrix)
self.assertAllClose(inv_matrix.dot(x), self.evaluate(operator.solvevec(x)))
self.assertAllClose(
inv_matrix.T.dot(x), self.evaluate(operator.H.solvevec(x)))
class LinearOperatorAdjointNonSquareTest(
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
......
......@@ -181,9 +181,8 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
return self.operator.matmul(
x, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _matvec(self, x, adjoint=False, adjoint_arg=False):
return self.operator.matvec(
x, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _matvec(self, x, adjoint=False):
return self.operator.matvec(x, adjoint=(not adjoint))
def _determinant(self):
if self.is_self_adjoint:
......@@ -202,9 +201,8 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
return self.operator.solve(
rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _solvevec(self, rhs, adjoint=False, adjoint_arg=False):
return self.operator.solvevec(
rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _solvevec(self, rhs, adjoint=False):
return self.operator.solvevec(rhs, adjoint=(not adjoint))
def _to_dense(self):
if self.is_self_adjoint:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册