Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
657b6742
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
657b6742
编写于
1月 10, 2022
作者:
Y
Yulong Ao
提交者:
GitHub
1月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the backward support for QR (#38824)
* Add the backward support for QR * Remove unnecessary comments
上级
953638e0
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
347 addition
and
3 deletion
+347
-3
paddle/fluid/operators/qr_op.h
paddle/fluid/operators/qr_op.h
+121
-2
paddle/fluid/operators/svd_helper.h
paddle/fluid/operators/svd_helper.h
+135
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_qr_op.py
python/paddle/fluid/tests/unittests/test_qr_op.py
+90
-1
未找到文件。
paddle/fluid/operators/qr_op.h
浏览文件 @
657b6742
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -79,9 +80,11 @@ class QrCPUKernel : public framework::OpKernel<T> {
...
@@ -79,9 +80,11 @@ class QrCPUKernel : public framework::OpKernel<T> {
q_data
=
q
.
mutable_data
<
math
::
Real
<
T
>>
(
q_data
=
q
.
mutable_data
<
math
::
Real
<
T
>>
(
context
.
GetPlace
(),
context
.
GetPlace
(),
size_t
(
batch_size
*
m
*
k
*
sizeof
(
math
::
Real
<
T
>
)));
size_t
(
batch_size
*
m
*
k
*
sizeof
(
math
::
Real
<
T
>
)));
memset
(
q_data
,
0
,
size_t
(
batch_size
*
m
*
k
*
sizeof
(
math
::
Real
<
T
>
)));
}
}
auto
*
r_data
=
r
.
mutable_data
<
math
::
Real
<
T
>>
(
auto
*
r_data
=
r
.
mutable_data
<
math
::
Real
<
T
>>
(
context
.
GetPlace
(),
size_t
(
batch_size
*
k
*
n
*
sizeof
(
math
::
Real
<
T
>
)));
context
.
GetPlace
(),
size_t
(
batch_size
*
k
*
n
*
sizeof
(
math
::
Real
<
T
>
)));
memset
(
r_data
,
0
,
size_t
(
batch_size
*
k
*
n
*
sizeof
(
math
::
Real
<
T
>
)));
// Implement QR by calling Eigen
// Implement QR by calling Eigen
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
...
@@ -126,8 +129,124 @@ template <typename DeviceContext, typename T>
...
@@ -126,8 +129,124 @@ template <typename DeviceContext, typename T>
class
QrGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
QrGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
const
framework
::
Tensor
&
Q
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Q"
);
"QR doesn't have the backward kernel now and will be supported soon."
));
const
framework
::
Tensor
&
R
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
"R"
);
// Use a different name A instead of X
const
framework
::
Tensor
&
A
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
framework
::
Tensor
&
dQ
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Q"
));
const
framework
::
Tensor
&
dR
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"R"
));
// Use a different name dA instead of dX
framework
::
Tensor
&
dA
=
*
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
dA
.
mutable_data
<
math
::
Real
<
T
>>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
math
::
SetConstant
<
DeviceContext
,
T
>
()(
dev_ctx
,
&
dA
,
T
(
0
));
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
DeviceContext
,
T
>
(
ctx
);
std
::
string
mode
=
ctx
.
Attr
<
std
::
string
>
(
"mode"
);
bool
compute_q
,
reduced
;
std
::
tie
(
compute_q
,
reduced
)
=
_parse_qr_mode
(
mode
);
if
(
!
compute_q
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The derivative of qr is not implemented when mode='r'."
));
}
auto
a_dims
=
A
.
dims
();
int
a_rank
=
a_dims
.
size
();
int
m
=
a_dims
[
a_rank
-
2
];
int
n
=
a_dims
[
a_rank
-
1
];
if
((
m
>
n
)
&&
(
!
reduced
))
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The derivative of qr is not implemented when mode='complete' and "
"nrows > ncols."
));
}
// m >= n case
auto
m_gt_n_case
=
[](
const
framework
::
ExecutionContext
&
ctx
,
math
::
DeviceIndependenceTensorOperations
<
DeviceContext
,
T
>&
dito
,
const
Tensor
&
dQ
,
const
Tensor
&
dR
,
const
Tensor
&
A
,
const
Tensor
&
Q
,
const
Tensor
&
R
)
->
framework
::
Tensor
{
// Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable
// Programming Tensor Networks.
// https://arxiv.org/abs/1903.09650 Section 3. QR factorization
// dR^H
framework
::
Tensor
R_term
;
if
(
ctx
.
HasInput
(
framework
::
GradVarName
(
"R"
)))
{
R_term
=
dito
.
Matmul
(
R
,
dito
.
Transpose
(
dR
));
}
else
{
R_term
=
dito
.
Fill
(
framework
::
vectorize
<
int
>
(
R
.
dims
()),
0
);
}
// dQ^H * Q
framework
::
Tensor
Q_term
;
if
(
ctx
.
HasInput
(
framework
::
GradVarName
(
"Q"
)))
{
Q_term
=
dito
.
Matmul
(
dito
.
Transpose
(
dQ
),
Q
);
}
else
{
Q_term
=
dito
.
Fill
(
framework
::
vectorize
<
int
>
(
R
.
dims
()),
0
);
}
framework
::
Tensor
M_tmp1
=
dito
.
Sub
(
R_term
,
Q_term
);
// Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity
framework
::
Tensor
M_tril_0
=
dito
.
TrilTriu
(
M_tmp1
,
0
,
true
);
framework
::
Tensor
M_tril_1
=
dito
.
TrilTriu
(
M_tmp1
,
-
1
,
true
);
framework
::
Tensor
M
=
dito
.
Add
(
M_tril_0
,
dito
.
Transpose
(
M_tril_1
));
framework
::
Tensor
rhs_term
;
if
(
ctx
.
HasInput
(
framework
::
GradVarName
(
"Q"
)))
{
rhs_term
=
dito
.
Add
(
dQ
,
dito
.
Matmul
(
Q
,
M
));
}
else
{
rhs_term
=
dito
.
Matmul
(
Q
,
M
);
}
// dA * R^H = rhs_term
auto
dA
=
dito
.
TriangularSolve
(
dito
.
Transpose
(
dito
.
Conj
(
dito
.
Transpose
(
R
))),
dito
.
Transpose
(
rhs_term
),
/*upper=*/
true
,
/*transpose=*/
false
,
/*unitriangular=*/
false
);
return
dito
.
Transpose
(
dA
);
};
if
(
m
>=
n
)
{
auto
dA_tmp
=
m_gt_n_case
(
ctx
,
dito
,
dQ
,
dR
,
A
,
Q
,
R
);
framework
::
TensorCopy
(
dA_tmp
,
dA
.
place
(),
&
dA
);
}
else
{
// If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
// Calculate dX and dY individually and concatenate them to get dA
dA
.
mutable_data
<
math
::
Real
<
T
>>
(
ctx
.
GetPlace
());
auto
Y
=
dito
.
Slice
(
A
,
{
-
1
},
{
m
},
{
n
});
auto
U
=
dito
.
Slice
(
R
,
{
-
1
},
{
0
},
{
m
});
framework
::
Tensor
dY
,
dX
,
dV
,
dR_tmp
,
dQ_prime
;
if
(
ctx
.
HasInput
(
framework
::
GradVarName
(
"R"
)))
{
dV
=
dito
.
Slice
(
dR
,
{
-
1
},
{
m
},
{
n
});
dR_tmp
=
dito
.
Slice
(
dR
,
{
-
1
},
{
0
},
{
m
});
// Y * dV^H
dQ_prime
=
dito
.
Matmul
(
Y
,
dito
.
Transpose
(
dV
));
}
else
{
dV
=
dito
.
Fill
(
framework
::
vectorize
<
int
>
(
Y
.
dims
()),
0
);
dQ_prime
=
dito
.
Fill
(
framework
::
vectorize
<
int
>
(
Q
.
dims
()),
0
);
}
if
(
ctx
.
HasInput
(
framework
::
GradVarName
(
"Q"
)))
{
dQ_prime
=
dito
.
Add
(
dQ_prime
,
dQ
);
}
dX
=
m_gt_n_case
(
ctx
,
dito
,
dQ_prime
,
dR_tmp
,
A
,
Q
,
U
);
dY
=
dito
.
Matmul
(
Q
,
dV
);
// Concatenate dX and dY to get dA.
auto
dA_tmp
=
dito
.
ConcatTwoTensors
(
dX
,
dY
,
-
1
);
framework
::
TensorCopy
(
dA_tmp
,
dA
.
place
(),
&
dA
);
}
}
}
};
};
...
...
paddle/fluid/operators/svd_helper.h
浏览文件 @
657b6742
...
@@ -146,6 +146,93 @@ static std::vector<int> GetBroadcastShape(InTensors ins) {
...
@@ -146,6 +146,93 @@ static std::vector<int> GetBroadcastShape(InTensors ins) {
return
broadcast_shape
;
return
broadcast_shape
;
}
}
static
inline
framework
::
DDim
ComputeAndCheckShapeForConcatOp
(
const
bool
is_runtime
,
const
std
::
vector
<
framework
::
DDim
>&
inputs_dims
,
const
size_t
axis
)
{
const
size_t
n
=
inputs_dims
.
size
();
auto
out_dims
=
inputs_dims
[
0
];
size_t
in_zero_dims_size
=
out_dims
.
size
();
for
(
size_t
i
=
1
;
i
<
n
;
i
++
)
{
PADDLE_ENFORCE_EQ
(
inputs_dims
[
i
].
size
(),
out_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The shape of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s]."
,
i
,
inputs_dims
[
0
],
i
,
inputs_dims
[
i
]));
for
(
size_t
j
=
0
;
j
<
in_zero_dims_size
;
j
++
)
{
if
(
j
==
axis
)
{
if
(
is_runtime
)
{
out_dims
[
axis
]
+=
inputs_dims
[
i
][
j
];
}
else
{
if
(
inputs_dims
[
i
][
j
]
==
-
1
||
out_dims
[
j
]
==
-
1
)
{
out_dims
[
axis
]
=
-
1
;
}
else
{
out_dims
[
axis
]
+=
inputs_dims
[
i
][
j
];
}
}
}
else
{
bool
check_shape
=
is_runtime
||
(
inputs_dims
[
0
][
j
]
>
0
&&
inputs_dims
[
i
][
j
]
>
0
);
if
(
check_shape
)
{
// check all shape in run time
PADDLE_ENFORCE_EQ
(
inputs_dims
[
0
][
j
],
inputs_dims
[
i
][
j
],
platform
::
errors
::
InvalidArgument
(
"The %d-th dimension of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s]."
,
j
,
i
,
inputs_dims
[
0
],
i
,
inputs_dims
[
i
]));
}
if
(
!
is_runtime
&&
out_dims
[
j
]
==
-
1
&&
inputs_dims
[
i
][
j
]
>
0
)
{
out_dims
[
j
]
=
inputs_dims
[
i
][
j
];
}
}
}
}
return
out_dims
;
}
static
inline
int64_t
ComputeAxisForConcatOp
(
int64_t
axis
,
int64_t
rank
)
{
PADDLE_ENFORCE_EQ
(
axis
>=
-
rank
&&
axis
<
rank
,
true
,
platform
::
errors
::
InvalidArgument
(
"The axis is expected to be in range of [%d, %d), but got %d"
,
-
rank
,
rank
,
axis
));
if
(
axis
<
0
)
{
axis
=
axis
+
rank
;
}
return
axis
>
0
?
axis
:
0
;
}
// Prepared for the broadcast operation
static
std
::
vector
<
int64_t
>
get_broadcast_batch_portion
(
std
::
vector
<
int64_t
>
x
,
std
::
vector
<
int64_t
>
y
)
{
size_t
size_x
=
x
.
size
();
size_t
size_y
=
y
.
size
();
size_t
size
=
std
::
max
(
size_x
,
size_y
);
std
::
vector
<
int64_t
>
batchPortion
(
size
);
ptrdiff_t
i
=
(
ptrdiff_t
)
size
-
1
;
for
(;
i
>=
0
;
--
i
)
{
ptrdiff_t
offset
=
size
-
i
-
1
;
ptrdiff_t
dim_x
=
size_x
-
offset
-
1
;
ptrdiff_t
dim_y
=
size_y
-
offset
-
1
;
int64_t
x_size
=
(
dim_x
>=
0
)
?
x
[
dim_x
]
:
1
;
int64_t
y_size
=
(
dim_y
>=
0
)
?
y
[
dim_y
]
:
1
;
PADDLE_ENFORCE_EQ
(
(
x_size
==
y_size
||
x_size
==
1
||
y_size
==
1
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The size of tensor x (%d) must match the size of tensor y "
"(%d) at non-singleton dimension %d."
,
x_size
,
y_size
,
i
));
batchPortion
[
i
]
=
x_size
!=
1
?
x_size
:
y_size
;
}
return
batchPortion
;
}
#define DITO_TRANSPOSE_RANK_CASE(N) \
#define DITO_TRANSPOSE_RANK_CASE(N) \
case N: { \
case N: { \
math::Transpose<DeviceContext, T, N> trans; \
math::Transpose<DeviceContext, T, N> trans; \
...
@@ -515,6 +602,54 @@ struct DeviceIndependenceTensorOperations {
...
@@ -515,6 +602,54 @@ struct DeviceIndependenceTensorOperations {
return
CreateOpRunAndReturnTensor
(
"tril_triu"
,
inputs
,
attrs
,
out_shape
);
return
CreateOpRunAndReturnTensor
(
"tril_triu"
,
inputs
,
attrs
,
out_shape
);
}
}
framework
::
Tensor
TriangularSolve
(
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
,
bool
upper
,
bool
transpose
,
bool
unitriangular
)
{
framework
::
AttributeMap
attrs
;
attrs
[
"upper"
]
=
upper
;
attrs
[
"transpose"
]
=
transpose
;
attrs
[
"unitriangular"
]
=
unitriangular
;
NameInTensorMap
inputs
({{
"X"
,
{
&
x
}},
{
"Y"
,
{
&
y
}}});
auto
x_dims
=
x
.
dims
();
auto
y_dims
=
y
.
dims
();
auto
y_dims_n
=
y_dims
.
size
();
std
::
vector
<
int64_t
>
x_dims_vec
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
x_dims
);
std
::
vector
<
int64_t
>
y_dims_vec
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
y_dims
);
std
::
vector
<
int64_t
>
x_dims_vec_cut
(
x_dims_vec
.
begin
(),
x_dims_vec
.
end
()
-
2
);
std
::
vector
<
int64_t
>
y_dims_vec_cut
(
y_dims_vec
.
begin
(),
y_dims_vec
.
end
()
-
2
);
std
::
vector
<
int64_t
>
expand_batch_portion
=
get_broadcast_batch_portion
(
x_dims_vec_cut
,
y_dims_vec_cut
);
std
::
vector
<
int64_t
>
y_broadcast_dims
({
expand_batch_portion
});
y_broadcast_dims
.
insert
(
y_broadcast_dims
.
end
(),
{
y_dims_vec
[
y_dims_n
-
2
],
y_dims_vec
[
y_dims_n
-
1
]});
std
::
vector
<
int
>
out_shape
(
y_broadcast_dims
.
begin
(),
y_broadcast_dims
.
end
());
return
CreateOpRunAndReturnTensor
(
"triangular_solve"
,
inputs
,
attrs
,
out_shape
);
}
framework
::
Tensor
ConcatTwoTensors
(
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
,
int
axis
)
{
framework
::
AttributeMap
attrs
;
attrs
[
"axis"
]
=
axis
;
std
::
vector
<
framework
::
DDim
>
inputs_dims
({
x
.
dims
(),
y
.
dims
()});
NameInTensorMap
inputs
({{
"X"
,
{
&
x
,
&
y
}}});
size_t
axis_
=
ComputeAxisForConcatOp
(
static_cast
<
int64_t
>
(
axis
),
static_cast
<
int64_t
>
(
inputs_dims
[
0
].
size
()));
framework
::
DDim
out_dims
=
ComputeAndCheckShapeForConcatOp
(
true
,
inputs_dims
,
axis_
);
if
(
out_dims
[
axis_
]
<
0
)
{
out_dims
[
axis_
]
=
-
1
;
}
std
::
vector
<
int
>
out_shape
=
framework
::
vectorize
<
int
>
(
out_dims
);
return
CreateOpRunAndReturnTensor
(
"concat"
,
inputs
,
attrs
,
out_shape
);
}
Tensor
Conj
(
const
Tensor
&
x
)
{
Tensor
Conj
(
const
Tensor
&
x
)
{
Tensor
out
;
Tensor
out
;
auto
*
out_data
=
out
.
mutable_data
<
T
>
(
x
.
dims
(),
context
.
GetPlace
());
auto
*
out_data
=
out
.
mutable_data
<
T
>
(
x
.
dims
(),
context
.
GetPlace
());
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
657b6742
...
@@ -975,6 +975,7 @@ set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
...
@@ -975,6 +975,7 @@ set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
set_tests_properties
(
test_stack_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_stack_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_svd_op PROPERTIES TIMEOUT 80
)
set_tests_properties
(
test_svd_op PROPERTIES TIMEOUT 80
)
set_tests_properties
(
test_qr_op PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_deformable_psroi_pooling PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_deformable_psroi_pooling PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/test_qr_op.py
浏览文件 @
657b6742
...
@@ -21,6 +21,96 @@ import paddle
...
@@ -21,6 +21,96 @@ import paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
import
paddle.fluid.layers
as
layers
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
class
TestQrOp
(
OpTest
):
def
setUp
(
self
):
paddle
.
enable_static
()
np
.
random
.
seed
(
4
)
self
.
op_type
=
"qr"
a
,
q
,
r
=
self
.
get_input_and_output
()
self
.
inputs
=
{
"X"
:
a
}
self
.
attrs
=
{
"mode"
:
self
.
get_mode
()}
self
.
outputs
=
{
"Q"
:
q
,
"R"
:
r
}
def
get_dtype
(
self
):
return
"float64"
def
get_mode
(
self
):
return
"reduced"
def
get_shape
(
self
):
return
(
11
,
11
)
def
get_input_and_output
(
self
):
dtype
=
self
.
get_dtype
()
shape
=
self
.
get_shape
()
mode
=
self
.
get_mode
()
assert
mode
!=
"r"
,
"Cannot be backward in r mode."
a
=
np
.
random
.
rand
(
*
shape
).
astype
(
dtype
)
m
=
a
.
shape
[
-
2
]
n
=
a
.
shape
[
-
1
]
min_mn
=
min
(
m
,
n
)
if
mode
==
"reduced"
:
k
=
min_mn
else
:
k
=
m
q_shape
=
list
(
a
.
shape
[:
-
2
])
q_shape
.
extend
([
m
,
k
])
r_shape
=
list
(
a
.
shape
[:
-
2
])
r_shape
.
extend
([
k
,
n
])
q
=
np
.
zeros
(
q_shape
).
astype
(
dtype
)
r
=
np
.
zeros
(
r_shape
).
astype
(
dtype
)
batch_size
=
a
.
size
//
(
a
.
shape
[
-
1
]
*
a
.
shape
[
-
2
])
for
i
in
range
(
batch_size
):
coord
=
np
.
unravel_index
(
i
,
a
.
shape
[:
-
2
])
tmp_q
,
tmp_r
=
np
.
linalg
.
qr
(
a
[
coord
],
mode
=
mode
)
q
[
coord
]
=
tmp_q
r
[
coord
]
=
tmp_r
return
a
,
q
,
r
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'X'
],
[
'Q'
,
'R'
])
class
TestQrOpCase1
(
TestQrOp
):
def
get_shape
(
self
):
return
(
10
,
12
)
class
TestQrOpCase2
(
TestQrOp
):
def
get_shape
(
self
):
return
(
16
,
15
)
class
TestQrOpCase3
(
TestQrOp
):
def
get_shape
(
self
):
return
(
2
,
12
,
16
)
class
TestQrOpCase4
(
TestQrOp
):
def
get_shape
(
self
):
return
(
3
,
16
,
15
)
class
TestQrOpCase5
(
TestQrOp
):
def
get_mode
(
self
):
return
"complete"
def
get_shape
(
self
):
return
(
10
,
12
)
class
TestQrOpCase6
(
TestQrOp
):
def
get_mode
(
self
):
return
"complete"
def
get_shape
(
self
):
return
(
2
,
10
,
12
)
class
TestQrAPI
(
unittest
.
TestCase
):
class
TestQrAPI
(
unittest
.
TestCase
):
...
@@ -169,5 +259,4 @@ class TestQrAPI(unittest.TestCase):
...
@@ -169,5 +259,4 @@ class TestQrAPI(unittest.TestCase):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录