Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2af286a6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2af286a6
编写于
7月 13, 2022
作者:
H
Haohongxiang
提交者:
GitHub
7月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs of paddle.linalg.lstsq (#44280)
上级
7cf72a38
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
39 addition
and
10 deletion
+39
-10
paddle/fluid/operators/lstsq_op.cu
paddle/fluid/operators/lstsq_op.cu
+16
-7
python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py
python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py
+23
-3
未找到文件。
paddle/fluid/operators/lstsq_op.cu
浏览文件 @
2af286a6
...
...
@@ -100,7 +100,7 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
true
,
batch_count
,
m
,
n
,
n
rhs
,
k
,
x_data
,
x_stride
,
...
...
@@ -137,14 +137,17 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
// Step 2, solve R^H Z = Y
Tensor
trans_r
=
dito
.
Transpose
(
new_x
);
Tensor
slice_r
=
dito
.
Slice
(
trans_r
,
{
-
2
},
{
0
},
{
min_mn
});
Tensor
res_r
=
dito
.
TrilTriu
(
slice_r
,
0
,
false
);
phi
::
TriangularSolveKernel
<
T
,
Context
>
(
phi_dev_ctx
,
tran
s_r
,
new_y
,
true
,
true
,
false
,
solution
);
phi_dev_ctx
,
re
s_r
,
new_y
,
true
,
true
,
false
,
solution
);
// Step 3, X <- Q Z
BatchedOrgqr
<
DeviceContext
,
T
>
(
dev_ctx
,
batch_count
,
n
,
n
,
m
,
min_mn
,
x_data
,
n
,
...
...
@@ -183,8 +186,6 @@ void BatchedOrmqr<platform::CUDADeviceContext, float>(
auto
handle
=
dev_ctx
.
cusolver_dn_handle
();
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusolverDnSormqr_bufferSize
(
handle
,
side
,
trans
,
m
,
n
,
k
,
a
,
lda
,
tau
,
other
,
ldc
,
&
lwork
));
auto
workspace
=
memory
::
Alloc
(
dev_ctx
,
lwork
*
sizeof
(
float
));
float
*
workspace_ptr
=
reinterpret_cast
<
float
*>
(
workspace
->
ptr
());
auto
info
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
int
));
int
*
info_d
=
reinterpret_cast
<
int
*>
(
info
->
ptr
());
...
...
@@ -192,6 +193,11 @@ void BatchedOrmqr<platform::CUDADeviceContext, float>(
float
*
a_working_ptr
=
&
a
[
i
*
a_stride
];
float
*
tau_working_ptr
=
&
tau
[
i
*
tau_stride
];
float
*
other_working_ptr
=
&
other
[
i
*
other_stride
];
handle
=
dev_ctx
.
cusolver_dn_handle
();
auto
workspace
=
memory
::
Alloc
(
dev_ctx
,
lwork
*
sizeof
(
float
));
float
*
workspace_ptr
=
reinterpret_cast
<
float
*>
(
workspace
->
ptr
());
// compute ormgr
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusolverDnSormqr
(
handle
,
...
...
@@ -249,8 +255,6 @@ void BatchedOrmqr<platform::CUDADeviceContext, double>(
auto
handle
=
dev_ctx
.
cusolver_dn_handle
();
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusolverDnDormqr_bufferSize
(
handle
,
side
,
trans
,
m
,
n
,
k
,
a
,
lda
,
tau
,
other
,
ldc
,
&
lwork
));
auto
workspace
=
memory
::
Alloc
(
dev_ctx
,
lwork
*
sizeof
(
double
));
double
*
workspace_ptr
=
reinterpret_cast
<
double
*>
(
workspace
->
ptr
());
auto
info
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
int
));
int
*
info_d
=
reinterpret_cast
<
int
*>
(
info
->
ptr
());
...
...
@@ -258,6 +262,11 @@ void BatchedOrmqr<platform::CUDADeviceContext, double>(
double
*
a_working_ptr
=
&
a
[
i
*
a_stride
];
double
*
tau_working_ptr
=
&
tau
[
i
*
tau_stride
];
double
*
other_working_ptr
=
&
other
[
i
*
other_stride
];
handle
=
dev_ctx
.
cusolver_dn_handle
();
auto
workspace
=
memory
::
Alloc
(
dev_ctx
,
lwork
*
sizeof
(
double
));
double
*
workspace_ptr
=
reinterpret_cast
<
double
*>
(
workspace
->
ptr
());
// compute ormgr
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusolverDnDormqr
(
handle
,
...
...
python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py
浏览文件 @
2af286a6
...
...
@@ -175,6 +175,16 @@ class LinalgLstsqTestCase2(LinalgLstsqTestCase):
self
.
_input_shape_2
=
(
5
,
8
)
class
LinalgLstsqTestCase3
(
LinalgLstsqTestCase
):
def
init_config
(
self
):
self
.
dtype
=
'float64'
self
.
rcond
=
1e-15
self
.
driver
=
"gels"
self
.
_input_shape_1
=
(
10
,
7
,
3
)
self
.
_input_shape_2
=
(
10
,
7
,
6
)
class
LinalgLstsqTestCaseRcond
(
LinalgLstsqTestCase
):
def
init_config
(
self
):
...
...
@@ -192,7 +202,17 @@ class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase):
self
.
rcond
=
None
self
.
driver
=
"gels"
self
.
_input_shape_1
=
(
10
,
5
)
self
.
_input_shape_2
=
(
10
,
2
)
self
.
_input_shape_2
=
(
10
,
8
)
class
LinalgLstsqTestCaseGelsFloat64
(
LinalgLstsqTestCase
):
def
init_config
(
self
):
self
.
dtype
=
'float32'
self
.
rcond
=
None
self
.
driver
=
"gels"
self
.
_input_shape_1
=
(
3
,
2
,
8
)
self
.
_input_shape_2
=
(
3
,
2
,
15
)
class
LinalgLstsqTestCaseGelssFloat64
(
LinalgLstsqTestCase
):
...
...
@@ -230,9 +250,9 @@ class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase):
def
init_config
(
self
):
self
.
dtype
=
'float64'
self
.
rcond
=
1e-15
self
.
driver
=
"gels
s
"
self
.
driver
=
"gels"
self
.
_input_shape_1
=
(
10
,
8
,
6
)
self
.
_input_shape_2
=
(
10
,
8
,
2
)
self
.
_input_shape_2
=
(
10
,
8
,
10
)
class
LinalgLstsqTestCaseLarge1
(
LinalgLstsqTestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录