Unverified commit 00efdf84, authored by zhangyuqin1998 and committed via GitHub

reorder MatrixRank (#52925)

* reorder MatrixRank

* fix

* fix

* fix

* fix

* fix
Parent commit: 069bb2d9
......@@ -755,11 +755,11 @@
backward : matmul_grad
- op : matrix_rank
args : (Tensor x, float tol, bool hermitian=false, bool use_default_tol=true)
args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
output : Tensor(out)
infer_meta :
func : MatrixRankInferMeta
param : [x, hermitian, use_default_tol]
param : [x, use_default_tol, hermitian]
kernel :
func : matrix_rank
......
......@@ -222,7 +222,7 @@
output : Tensor(out)
infer_meta :
func : MatrixRankStaticInferMeta
param : [x, tol_tensor, hermitian, use_default_tol]
param : [x, tol_tensor, use_default_tol, hermitian]
optional : tol_tensor
kernel :
func : matrix_rank {dense -> dense},
......
......@@ -2207,7 +2207,7 @@ void MatrixRankStaticInferMeta(const MetaTensor& x,
if (atol_tensor) {
MatrixRankTolInferMeta(x, atol_tensor, use_default_tol, hermitian, out);
} else {
MatrixRankInferMeta(x, hermitian, use_default_tol, out);
MatrixRankInferMeta(x, use_default_tol, hermitian, out);
}
}
......
......@@ -2006,8 +2006,8 @@ void LUInferMeta(const MetaTensor& x,
}
void MatrixRankInferMeta(const MetaTensor& x,
bool hermitian,
bool use_default_tol,
bool hermitian,
MetaTensor* out) {
auto dim_x = x.dims();
PADDLE_ENFORCE_GE(dim_x.size(),
......
......@@ -272,8 +272,8 @@ void LUInferMeta(const MetaTensor& x,
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
void MatrixRankInferMeta(const MetaTensor& x,
bool hermitian,
bool use_default_tol,
bool hermitian,
MetaTensor* out);
void MaxOutInferMeta(const MetaTensor& x,
......
......@@ -24,8 +24,8 @@ template <typename T, typename Context>
void MatrixRankKernel(const Context& dev_ctx,
const DenseTensor& x,
float tol,
bool hermitian,
bool use_default_tol,
bool hermitian,
DenseTensor* out) {
DenseTensor atol_tensor;
if (use_default_tol) {
......
......@@ -27,8 +27,8 @@ template <typename T, typename Context>
void MatrixRankKernel(const Context& dev_ctx,
const DenseTensor& x,
float tol,
bool hermitian,
bool use_default_tol,
bool hermitian,
DenseTensor* out) {
DenseTensor atol_tensor;
if (use_default_tol) {
......
......@@ -22,8 +22,8 @@ template <typename T, typename Context>
void MatrixRankKernel(const Context& dev_ctx,
const DenseTensor& x,
float tol,
bool hermitian,
bool use_default_tol,
bool hermitian,
DenseTensor* out);
} // namespace phi
......@@ -34,8 +34,8 @@ KernelSignature MatrixRankOpArgumentMapping(const ArgumentMappingContext& ctx) {
{"X"},
{
"tol",
"hermitian",
"use_default_tol",
"hermitian",
},
{"Out"});
}
......
......@@ -26,7 +26,7 @@ SEED = 2049
np.random.seed(SEED)
def matrix_rank_wraper(x, tol=None, hermitian=False, use_default_tol=True):
def matrix_rank_wraper(x, tol=None, use_default_tol=True, hermitian=False):
    # Thin test helper around paddle.linalg.matrix_rank.
    # NOTE: `use_default_tol` is accepted only so the helper's signature
    # mirrors the reordered op arguments; it is NOT forwarded to the API.
    # Presumably passing tol=None already selects the default tolerance
    # inside paddle.linalg.matrix_rank — TODO confirm against the API.
    return paddle.linalg.matrix_rank(x, tol, hermitian)
......
......@@ -1505,7 +1505,7 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
else:
tol_attr = float(tol)
use_default_tol = False
return _C_ops.matrix_rank(x, tol_attr, hermitian, use_default_tol)
return _C_ops.matrix_rank(x, tol_attr, use_default_tol, hermitian)
else:
inputs = {}
attrs = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register