Unverified commit 28fffef6, authored by 0x45f, committed by GitHub

refine matrix_rank op code and doc (#35722)

Parent 5548061b
@@ -27,7 +27,7 @@ namespace operators {
 using DDim = framework::DDim;
 namespace detail {
-static DDim GetInputBatchDim(const DDim& dim_x) {
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = framework::vectorize(dim_x);
   if (x_vec.size() == 2) {
     return framework::make_ddim({1});
@@ -58,11 +58,8 @@ class MatrixRankeOp : public framework::OperatorWithKernel {
                             "if hermitian == true, matrix should be n*n"));
     }
-    DDim dim_x_batch = detail::GetInputBatchDim(dim_x);
-    if (ctx->Attrs().Get<bool>(
-            "use_default_tol")) {  // user not input TolTensor and tol
-      ctx->SetOutputDim("Out", dim_x_batch);
-    } else if (ctx->HasInput("TolTensor")) {
+    DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+    if (ctx->HasInput("TolTensor")) {
       auto dim_tol = ctx->GetInputDim("TolTensor");
       if (dim_x_batch == dim_tol) {
         ctx->SetOutputDim("Out", dim_x_batch);
@@ -75,9 +72,6 @@ class MatrixRankeOp : public framework::OperatorWithKernel {
         GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(),
                                tol_dims_array.data(), out_dims_array.data(),
                                max_dim, axis);
-        for (auto& it : out_dims_array) {
-          VLOG(3) << "out dims: " << it;
-        }
         ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
       }
     } else {
@@ -100,7 +94,9 @@ class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "(Tensor), The input tensor of matrix_rank op.");
-    AddInput("TolTensor", "(optional) Tol tensor, shape is same as X batch.")
+    AddInput("TolTensor",
+             "(optional) Tol tensor, shape is same as X batch or can broadcast "
+             "with X batch.")
         .AsDispensable();
     AddOutput("Out", "(Tensor), The output tensor of matrix_rank op.");
     AddAttr<float>("tol", "(float, optional). tol").SetDefault(0.0f);
......
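With the refactored InferShape above, the output shape is simply the broadcast of X's batch dimensions with TolTensor's shape (via GetBroadcastDimsArrays); the old use_default_tol special case is gone. As an editor's illustration (not code from this commit), here is a minimal numpy sketch of that rule, using shapes similar to the tests below:

import numpy as np

# Batch of 3 x 4 matrices, each 5 x 6, and a tolerance of shape (1, 4)
# that broadcasts against the batch dims (3, 4).
x = np.random.rand(3, 4, 5, 6).astype(np.float32)
tol = np.random.random([1, 4]).astype(np.float32)

# Rank = number of singular values above tol, counted per matrix.
sigma = np.linalg.svd(x, compute_uv=False)           # shape (3, 4, 5)
rank = (sigma > tol[..., np.newaxis]).sum(axis=-1)   # shape (3, 4)

# np.linalg.matrix_rank, the reference used by the op tests, applies the
# same threshold rule and the same broadcasting of tol.
assert np.array_equal(rank, np.linalg.matrix_rank(x, tol))
print(rank.shape)  # (3, 4): broadcast of X's batch dims with tol's shape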
@@ -36,8 +36,8 @@ class TestMatrixRankOP(OpTest):
         self.init_data()
         self.inputs = {'X': self.x}
         self.attrs = {'hermitian': self.hermitian}
-        if self.tolTensor is not None:
-            self.inputs["TolTensor"] = self.tolTensor
+        if self.tol_tensor is not None:
+            self.inputs["TolTensor"] = self.tol_tensor
         if self.tol is not None:
             self.attrs["tol"] = self.tol
         self.attrs["use_default_tol"] = self.use_default_tol
@@ -48,7 +48,7 @@ class TestMatrixRankOP(OpTest):
     def init_data(self):
         self.x = np.eye(3, dtype=np.float32)
-        self.tolTensor = None
+        self.tol_tensor = None
         self.tol = 0.1
         self.use_default_tol = False
         self.hermitian = True
@@ -58,51 +58,56 @@ class TestMatrixRankOP(OpTest):
 class TestMatrixRankOP1(TestMatrixRankOP):
     def init_data(self):
         self.x = np.eye(3, k=1, dtype=np.float64)
-        self.tolTensor = None
+        self.tol_tensor = None
         self.tol = None
         self.use_default_tol = True
         self.hermitian = False
-        self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian)
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)


 class TestMatrixRankOP2(TestMatrixRankOP):
     def init_data(self):
         self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
-        self.tolTensor = np.random.random([3, 4]).astype(self.x.dtype)
+        self.tol_tensor = np.random.random([3, 4]).astype(self.x.dtype)
         self.tol = None
         self.use_default_tol = False
         self.hermitian = False
-        self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian)
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)


 class TestMatrixRankOP3(TestMatrixRankOP):
     def init_data(self):
         self.x = np.eye(200, dtype=np.float64)
-        self.tolTensor = None
+        self.tol_tensor = None
         self.tol = None
         self.use_default_tol = True
         self.hermitian = True
-        self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian)
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)


 class TestMatrixRankOP4(TestMatrixRankOP):
     def init_data(self):
         self.x = np.random.rand(1, 10).astype(np.float32)
-        self.tolTensor = None
+        self.tol_tensor = None
         self.tol = None
         self.use_default_tol = True
         self.hermitian = False
-        self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian)
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)


 class TestMatrixRankOP5(TestMatrixRankOP):
     def init_data(self):
         self.x = np.random.rand(5, 1).astype(np.float64)
-        self.tolTensor = np.random.random([1, 4]).astype(self.x.dtype)
+        self.tol_tensor = np.random.random([1, 4]).astype(self.x.dtype)
         self.tol = None
         self.use_default_tol = False
         self.hermitian = False
-        self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian)
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)


 class TestMatrixRankAPI(unittest.TestCase):
......
@@ -1106,20 +1106,18 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
     r"""
     Computes the rank of a matrix.

-    The rank of a matrix is the number of singular values that are greater than the specified tol threshold when hermitian=False,
-    or the number of eigenvalues in absolute value that are greater than the specified tol threshold when hermitian=True.
+    The rank of a matrix is the number of singular values that are greater than the specified `tol` threshold when hermitian=False,
+    or the number of eigenvalues in absolute value that are greater than the specified `tol` threshold when hermitian=True.

     Args:
-        x (Tensor): The input tensor.
-            Its shape should be [..., m, n], where ... is zero or more batch dimensions. If x is a batch of matrices then the output
-            has the same batch dimensions. The data type of x should be float32 or float64.
-        tol (float,Tensor,optional): the tolerance value. Default: None.
-            If tol is not specified, and sigma is the largest singular value (or eigenvalue in absolute value), and eps is the
-            epsilon value for the dtype of x, then tol is computed with formula tol=sigma * max(m,n) * eps. Note that if x is
-            a batch of matrices, tol is computed this way for every batch.
-        hermitian (bool,optional): indicates whether x is Hermitian. Default: False.
-            When hermitian=True, x is assumed to be Hermitian, but x is not checked inside the function. Instead, We just use the
-            lower triangular of the matrix to compute.
+        x (Tensor): The input tensor. Its shape should be `[..., m, n]`, where `...` is zero or more batch dimensions. If `x` is a batch
+            of matrices then the output has the same batch dimensions. The data type of `x` should be float32 or float64.
+        tol (float,Tensor,optional): the tolerance value. Default: None. If `tol` is not specified, and `sigma` is the largest
+            singular value (or eigenvalues in absolute value), and `eps` is the epsilon value for the dtype of `x`, then `tol` is computed
+            with formula `tol=sigma * max(m,n) * eps`. Note that if `x` is a batch of matrices, `tol` is computed this way for every batch.
+        hermitian (bool,optional): indicates whether `x` is Hermitian. Default: False. When hermitian=True, `x` is assumed to be Hermitian,
+            enabling a more efficient method for finding eigenvalues, but `x` is not checked inside the function. Instead, We just use
+            the lower triangular of the matrix to compute.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
......
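The docstring above defines the default tolerance as `tol = sigma * max(m, n) * eps`, computed per matrix in the batch. A short usage sketch of the documented API, added here as an editor's illustration rather than part of the commit; the values noted in comments follow from the inputs chosen:

import paddle

# A 4 x 4 matrix with only two linearly independent rows.
x = paddle.to_tensor([[1., 2., 3., 4.],
                      [2., 4., 6., 8.],
                      [2., 4., 6., 8.],
                      [0., 0., 0., 1.]])

# Default tol: sigma * max(m, n) * eps, with sigma the largest singular value.
print(paddle.linalg.matrix_rank(x))  # rank 2: only two independent rows

# tol may also be given explicitly, as a float or as a Tensor that broadcasts
# with the batch dimensions of x; a larger tol can only lower the reported rank.
print(paddle.linalg.matrix_rank(x, tol=0.5))

# hermitian=True reads only the lower triangle and compares |eigenvalues| to tol.
h = paddle.to_tensor([[2., 1.], [1., 2.]])
print(paddle.linalg.matrix_rank(h, hermitian=True))  # eigenvalues 1 and 3, rank 2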