未验证 提交 28fffef6 编写于 作者: 0 0x45f 提交者: GitHub

refine matrix_rank op code and doc (#35722)

上级 5548061b
...@@ -27,7 +27,7 @@ namespace operators { ...@@ -27,7 +27,7 @@ namespace operators {
using DDim = framework::DDim; using DDim = framework::DDim;
namespace detail { namespace detail {
static DDim GetInputBatchDim(const DDim& dim_x) { static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = framework::vectorize(dim_x); auto x_vec = framework::vectorize(dim_x);
if (x_vec.size() == 2) { if (x_vec.size() == 2) {
return framework::make_ddim({1}); return framework::make_ddim({1});
...@@ -58,11 +58,8 @@ class MatrixRankeOp : public framework::OperatorWithKernel { ...@@ -58,11 +58,8 @@ class MatrixRankeOp : public framework::OperatorWithKernel {
"if hermitian == true, matrix should be n*n")); "if hermitian == true, matrix should be n*n"));
} }
DDim dim_x_batch = detail::GetInputBatchDim(dim_x); DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
if (ctx->Attrs().Get<bool>( if (ctx->HasInput("TolTensor")) {
"use_default_tol")) { // user not input TolTensor and tol
ctx->SetOutputDim("Out", dim_x_batch);
} else if (ctx->HasInput("TolTensor")) {
auto dim_tol = ctx->GetInputDim("TolTensor"); auto dim_tol = ctx->GetInputDim("TolTensor");
if (dim_x_batch == dim_tol) { if (dim_x_batch == dim_tol) {
ctx->SetOutputDim("Out", dim_x_batch); ctx->SetOutputDim("Out", dim_x_batch);
...@@ -75,9 +72,6 @@ class MatrixRankeOp : public framework::OperatorWithKernel { ...@@ -75,9 +72,6 @@ class MatrixRankeOp : public framework::OperatorWithKernel {
GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(), GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(),
tol_dims_array.data(), out_dims_array.data(), tol_dims_array.data(), out_dims_array.data(),
max_dim, axis); max_dim, axis);
for (auto& it : out_dims_array) {
VLOG(3) << "out dims: " << it;
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array)); ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
} }
} else { } else {
...@@ -100,7 +94,9 @@ class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -100,7 +94,9 @@ class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor), The input tensor of matrix_rank op."); AddInput("X", "(Tensor), The input tensor of matrix_rank op.");
AddInput("TolTensor", "(optional) Tol tensor, shape is same as X batch.") AddInput("TolTensor",
"(optional) Tol tensor, shape is same as X batch or can broadcast "
"with X batch.")
.AsDispensable(); .AsDispensable();
AddOutput("Out", "(Tensor), The output tensor of matrix_rank op."); AddOutput("Out", "(Tensor), The output tensor of matrix_rank op.");
AddAttr<float>("tol", "(float, optional). tol").SetDefault(0.0f); AddAttr<float>("tol", "(float, optional). tol").SetDefault(0.0f);
......
...@@ -36,8 +36,8 @@ class TestMatrixRankOP(OpTest): ...@@ -36,8 +36,8 @@ class TestMatrixRankOP(OpTest):
self.init_data() self.init_data()
self.inputs = {'X': self.x} self.inputs = {'X': self.x}
self.attrs = {'hermitian': self.hermitian} self.attrs = {'hermitian': self.hermitian}
if self.tolTensor is not None: if self.tol_tensor is not None:
self.inputs["TolTensor"] = self.tolTensor self.inputs["TolTensor"] = self.tol_tensor
if self.tol is not None: if self.tol is not None:
self.attrs["tol"] = self.tol self.attrs["tol"] = self.tol
self.attrs["use_default_tol"] = self.use_default_tol self.attrs["use_default_tol"] = self.use_default_tol
...@@ -48,7 +48,7 @@ class TestMatrixRankOP(OpTest): ...@@ -48,7 +48,7 @@ class TestMatrixRankOP(OpTest):
def init_data(self): def init_data(self):
self.x = np.eye(3, dtype=np.float32) self.x = np.eye(3, dtype=np.float32)
self.tolTensor = None self.tol_tensor = None
self.tol = 0.1 self.tol = 0.1
self.use_default_tol = False self.use_default_tol = False
self.hermitian = True self.hermitian = True
...@@ -58,51 +58,56 @@ class TestMatrixRankOP(OpTest): ...@@ -58,51 +58,56 @@ class TestMatrixRankOP(OpTest):
class TestMatrixRankOP1(TestMatrixRankOP): class TestMatrixRankOP1(TestMatrixRankOP):
def init_data(self): def init_data(self):
self.x = np.eye(3, k=1, dtype=np.float64) self.x = np.eye(3, k=1, dtype=np.float64)
self.tolTensor = None self.tol_tensor = None
self.tol = None self.tol = None
self.use_default_tol = True self.use_default_tol = True
self.hermitian = False self.hermitian = False
self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
self.hermitian)
class TestMatrixRankOP2(TestMatrixRankOP): class TestMatrixRankOP2(TestMatrixRankOP):
def init_data(self): def init_data(self):
self.x = np.random.rand(3, 4, 5, 6).astype(np.float32) self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
self.tolTensor = np.random.random([3, 4]).astype(self.x.dtype) self.tol_tensor = np.random.random([3, 4]).astype(self.x.dtype)
self.tol = None self.tol = None
self.use_default_tol = False self.use_default_tol = False
self.hermitian = False self.hermitian = False
self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
self.hermitian)
class TestMatrixRankOP3(TestMatrixRankOP): class TestMatrixRankOP3(TestMatrixRankOP):
def init_data(self): def init_data(self):
self.x = np.eye(200, dtype=np.float64) self.x = np.eye(200, dtype=np.float64)
self.tolTensor = None self.tol_tensor = None
self.tol = None self.tol = None
self.use_default_tol = True self.use_default_tol = True
self.hermitian = True self.hermitian = True
self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
self.hermitian)
class TestMatrixRankOP4(TestMatrixRankOP): class TestMatrixRankOP4(TestMatrixRankOP):
def init_data(self): def init_data(self):
self.x = np.random.rand(1, 10).astype(np.float32) self.x = np.random.rand(1, 10).astype(np.float32)
self.tolTensor = None self.tol_tensor = None
self.tol = None self.tol = None
self.use_default_tol = True self.use_default_tol = True
self.hermitian = False self.hermitian = False
self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
self.hermitian)
class TestMatrixRankOP5(TestMatrixRankOP): class TestMatrixRankOP5(TestMatrixRankOP):
def init_data(self): def init_data(self):
self.x = np.random.rand(5, 1).astype(np.float64) self.x = np.random.rand(5, 1).astype(np.float64)
self.tolTensor = np.random.random([1, 4]).astype(self.x.dtype) self.tol_tensor = np.random.random([1, 4]).astype(self.x.dtype)
self.tol = None self.tol = None
self.use_default_tol = False self.use_default_tol = False
self.hermitian = False self.hermitian = False
self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
self.hermitian)
class TestMatrixRankAPI(unittest.TestCase): class TestMatrixRankAPI(unittest.TestCase):
......
...@@ -1106,20 +1106,18 @@ def matrix_rank(x, tol=None, hermitian=False, name=None): ...@@ -1106,20 +1106,18 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
r""" r"""
Computes the rank of a matrix. Computes the rank of a matrix.
The rank of a matrix is the number of singular values that are greater than the specified tol threshold when hermitian=False, The rank of a matrix is the number of singular values that are greater than the specified `tol` threshold when hermitian=False,
or the number of eigenvalues in absolute value that are greater than the specified tol threshold when hermitian=True. or the number of eigenvalues in absolute value that are greater than the specified `tol` threshold when hermitian=True.
Args: Args:
x (Tensor): The input tensor. x (Tensor): The input tensor. Its shape should be `[..., m, n]`, where `...` is zero or more batch dimensions. If `x` is a batch
Its shape should be [..., m, n], where ... is zero or more batch dimensions. If x is a batch of matrices then the output of matrices then the output has the same batch dimensions. The data type of `x` should be float32 or float64.
has the same batch dimensions. The data type of x should be float32 or float64. tol (float,Tensor,optional): the tolerance value. Default: None. If `tol` is not specified, and `sigma` is the largest
tol (float,Tensor,optional): the tolerance value. Default: None. singular value (or eigenvalues in absolute value), and `eps` is the epsilon value for the dtype of `x`, then `tol` is computed
If tol is not specified, and sigma is the largest singular value (or eigenvalue in absolute value), and eps is the with formula `tol=sigma * max(m,n) * eps`. Note that if `x` is a batch of matrices, `tol` is computed this way for every batch.
epsilon value for the dtype of x, then tol is computed with formula tol=sigma * max(m,n) * eps. Note that if x is hermitian (bool,optional): indicates whether `x` is Hermitian. Default: False. When hermitian=True, `x` is assumed to be Hermitian,
a batch of matrices, tol is computed this way for every batch. enabling a more efficient method for finding eigenvalues, but `x` is not checked inside the function. Instead, We just use
hermitian (bool,optional): indicates whether x is Hermitian. Default: False. the lower triangular of the matrix to compute.
When hermitian=True, x is assumed to be Hermitian, but x is not checked inside the function. Instead, We just use the
lower triangular of the matrix to compute.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册