Unverified commit c77a263d authored by Ruibiao Chen, committed by GitHub

Add yaml for matrix rank op (#41466)

* modify matrix_rank

* add matrix_rank shape

* add matrix_rank shape

* Add yaml for matrix_rank OP

* Add UT
Co-authored-by: zhoujianqian <15205085056@163.com>
Parent 5516f180
......@@ -64,6 +64,16 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
}
}
// Used in MatrixRankTolInferMeta
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = phi::vectorize(dim_x);
if (x_vec.size() == 2) {
return phi::make_ddim({1});
}
x_vec.erase(x_vec.end() - 2, x_vec.end());
return phi::make_ddim(x_vec);
}
} // namespace detail
void AllValueCompareInferMeta(const MetaTensor& x,
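The detail::CheckAndGetOutputDim helper above determines the shape of the rank output: a bare 2-D matrix produces a rank of shape [1], while batched input keeps only its leading batch dims. A minimal Python sketch of the same rule (the name batch_shape is illustrative, not part of the patch):

def batch_shape(dims):
    # Mirrors detail::CheckAndGetOutputDim: a plain 2-D matrix yields a
    # rank output of shape [1]; otherwise drop the trailing (rows, cols)
    # pair and keep the leading batch dims.
    if len(dims) == 2:
        return [1]
    return list(dims[:-2])

assert batch_shape([4, 5]) == [1]
assert batch_shape([2, 3, 4, 5]) == [2, 3]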
......@@ -1465,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
out->share_lod(x);
}
void MatrixRankTolInferMeta(const MetaTensor& x,
const MetaTensor& atol_tensor,
bool use_default_tol,
bool hermitian,
MetaTensor* out) {
auto dim_x = x.dims();
PADDLE_ENFORCE_GE(
dim_x.size(),
2,
phi::errors::InvalidArgument("The dims of input must be greater than 2"));
if (hermitian) {
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
PADDLE_ENFORCE_EQ(rows,
cols,
phi::errors::InvalidArgument(
"if hermitian == true, matrix should be n*n"));
}
DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
auto dim_tol = atol_tensor.dims();
if (dim_x_batch == dim_tol) {
out->set_dims(dim_x_batch);
} else {
int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
int axis = std::abs(dim_x_batch.size() - dim_tol.size());
std::vector<int> x_batch_dims_array(max_dim);
std::vector<int> tol_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
dim_tol,
x_batch_dims_array.data(),
tol_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out->set_dims(phi::make_ddim(out_dims_array));
}
out->share_lod(x);
}
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
auto dim_x = x.dims();
auto dim_vec = vec.dims();
......
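When a tolerance tensor is supplied, MatrixRankTolInferMeta broadcasts the batch shape of x against the shape of atol_tensor via GetBroadcastDimsArrays. Since those are the usual broadcasting rules, the resulting output shape can be predicted with NumPy's broadcast_shapes (a sketch assuming NumPy >= 1.20; the helper name is illustrative):

import numpy as np

def matrix_rank_tol_out_shape(x_shape, tol_shape):
    # Batch dims of x, as in detail::CheckAndGetOutputDim.
    batch = (1,) if len(x_shape) == 2 else tuple(x_shape[:-2])
    # Broadcast the batch dims against the tolerance tensor's shape.
    return np.broadcast_shapes(batch, tuple(tol_shape))

# A (2, 3, 4, 5) input has batch dims (2, 3); a tolerance of shape (3,)
# broadcasts against them, giving a rank output of shape (2, 3).
assert matrix_rank_tol_out_shape((2, 3, 4, 5), (3,)) == (2, 3)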
......@@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
int y_num_col_dims,
MetaTensor* out);
void MatrixRankTolInferMeta(const MetaTensor& x,
const MetaTensor& atol_tensor,
bool use_default_tol,
bool hermitian,
MetaTensor* out);
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
void PReluInferMeta(const MetaTensor& x,
......
......@@ -31,6 +31,18 @@ limitations under the License. */
namespace phi {
namespace detail {
// Used in MatrixRankInferMeta
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = phi::vectorize(dim_x);
if (x_vec.size() == 2) {
return phi::make_ddim({1});
}
x_vec.erase(x_vec.end() - 2, x_vec.end());
return phi::make_ddim(x_vec);
}
} // namespace detail
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
......@@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
out->set_dtype(x.dtype());
}
void MatrixRankInferMeta(const MetaTensor& x,
bool use_default_tol,
bool hermitian,
MetaTensor* out) {
auto dim_x = x.dims();
PADDLE_ENFORCE_GE(
dim_x.size(),
2,
phi::errors::InvalidArgument("The dims of input must be greater than 2"));
if (hermitian) {
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
PADDLE_ENFORCE_EQ(rows,
cols,
phi::errors::InvalidArgument(
"if hermitian == true, matrix should be n*n"));
}
DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
out->set_dims(dim_x_batch);
out->share_lod(x);
}
void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
......
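In API terms, MatrixRankInferMeta means the rank of a batched input has exactly the batch shape computed above. A short usage sketch (assumes a Paddle build with this patch):

import paddle

x = paddle.rand([2, 3, 4, 5])
rank = paddle.linalg.matrix_rank(x)  # default tolerance, hermitian=False
print(rank.shape)                    # [2, 3], the batch dims of x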
......@@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input,
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
void MatrixRankInferMeta(const MetaTensor& x,
bool use_default_tol,
bool hermitian,
MetaTensor* out);
void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
......
......@@ -30,8 +30,13 @@ SEED = 2049
np.random.seed(SEED)
def matrix_rank_wrapper(x, tol=None, use_default_tol=True, hermitian=False):
return paddle.linalg.matrix_rank(x, tol, hermitian)
class TestMatrixRankOP(OpTest):
def setUp(self):
self.python_api = matrix_rank_wrapper
self.op_type = "matrix_rank"
self.init_data()
self.inputs = {'X': self.x}
......@@ -44,7 +49,7 @@ class TestMatrixRankOP(OpTest):
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def init_data(self):
self.x = np.eye(3, dtype=np.float32)
......@@ -110,6 +115,28 @@ class TestMatrixRankOP5(TestMatrixRankOP):
self.hermitian)
class TestMatrixRankOP6(TestMatrixRankOP):
def init_data(self):
self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
self.tol_tensor = None
self.tol = None
self.use_default_tol = False
self.hermitian = False
self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
self.hermitian)
class TestMatrixRankOP7(TestMatrixRankOP):
def init_data(self):
self.x = np.eye(200, dtype=np.float64)
self.tol_tensor = np.random.random([200, 200]).astype(self.x.dtype)
self.tol = None
self.use_default_tol = True
self.hermitian = True
self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
self.hermitian)
class TestMatrixRankAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
......
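The tests use np.linalg.matrix_rank as the reference implementation, which itself maps over leading batch dims. A quick check of the reference shape exercised by TestMatrixRankOP6 (assumes NumPy >= 1.14, where stacked inputs are supported):

import numpy as np

x = np.random.rand(3, 4, 5, 6).astype(np.float32)
# np.linalg.matrix_rank maps over the leading batch dims, so the
# reference output for a (3, 4, 5, 6) input has shape (3, 4).
out = np.linalg.matrix_rank(x, hermitian=False)
assert out.shape == (3, 4)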
......@@ -1284,8 +1284,26 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
# [1, 1, 1, 1]]
"""
if in_dygraph_mode():
if isinstance(tol, Variable):
if tol.dtype != x.dtype:
tol_tensor = cast(tol, x.dtype)
else:
tol_tensor = tol
use_default_tol = False
return _C_ops.final_state_matrix_rank_tol(
x, tol_tensor, use_default_tol, hermitian)
if paddle.in_dynamic_mode():
if tol is None:
tol_attr = 0.0
use_default_tol = True
else:
tol_attr = float(tol)
use_default_tol = False
return _C_ops.final_state_matrix_rank(x, tol_attr, use_default_tol,
hermitian)
if _in_legacy_dygraph():
if tol is None:
tol_tensor = None
tol_attr = 0.0
......
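The new in_dygraph_mode branch above selects between the two final-state ops based on the type of tol: a Tensor tolerance routes to final_state_matrix_rank_tol (after a dtype cast if needed), while a float or None routes to final_state_matrix_rank with tol passed as a scalar attribute. Both paths are reached through the same public API; a hedged usage sketch (assumes eager mode with this patch applied):

import paddle

x = paddle.eye(10)
# Float tolerance: dispatches to final_state_matrix_rank.
r1 = paddle.linalg.matrix_rank(x, tol=1e-6, hermitian=True)
# Tensor tolerance: dispatches to final_state_matrix_rank_tol; the
# tolerance is cast to x.dtype first if the dtypes differ.
r2 = paddle.linalg.matrix_rank(x, tol=paddle.to_tensor(1e-6), hermitian=True)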
......@@ -1157,6 +1157,23 @@
func : matrix_power
backward : matrix_power_grad
- api : matrix_rank
args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
output : Tensor(out)
infer_meta :
func : MatrixRankInferMeta
param : [x, use_default_tol, hermitian]
kernel :
func : matrix_rank
- api : matrix_rank_tol
args : (Tensor x, Tensor atol_tensor, bool use_default_tol=true, bool hermitian=false)
output : Tensor(out)
infer_meta :
func : MatrixRankTolInferMeta
kernel :
func : matrix_rank_tol
- api : max
args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
output : Tensor(out)
......