Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
64c268b2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
64c268b2
编写于
3月 22, 2022
作者:
Z
zhiboniu
提交者:
GitHub
3月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add more annotations to test_cholesky_solve_op.py, make it an example in hackson guide
上级
7fc0c619
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
26 addition
and
13 deletion
+26
-13
python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py
...on/paddle/fluid/tests/unittests/test_cholesky_solve_op.py
+26
-13
未找到文件。
python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py
浏览文件 @
64c268b2
...
...
@@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core
paddle
.
enable_static
()
#cholesky_solve implement 1
def
cholesky_solution
(
X
,
B
,
upper
=
True
):
if
upper
:
A
=
np
.
triu
(
X
)
...
...
@@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True):
L
,
B
,
lower
=
True
))
#cholesky_solve implement 2
def
scipy_cholesky_solution
(
X
,
B
,
upper
=
True
):
if
upper
:
umat
=
np
.
triu
(
X
)
...
...
@@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True):
return
scipy
.
linalg
.
cho_solve
(
K
,
B
)
def
boardcast_shape
(
matA
,
matB
):
#broadcast function used by cholesky_solve
def
broadcast_shape
(
matA
,
matB
):
shapeA
=
matA
.
shape
shapeB
=
matB
.
shape
B
oar
dshape
=
[]
B
roa
dshape
=
[]
for
idx
in
range
(
len
(
shapeA
)
-
2
):
if
shapeA
[
idx
]
==
shapeB
[
idx
]:
B
oar
dshape
.
append
(
shapeA
[
idx
])
B
roa
dshape
.
append
(
shapeA
[
idx
])
continue
elif
shapeA
[
idx
]
==
1
or
shapeB
[
idx
]
==
1
:
B
oar
dshape
.
append
(
max
(
shapeA
[
idx
],
shapeB
[
idx
]))
B
roa
dshape
.
append
(
max
(
shapeA
[
idx
],
shapeB
[
idx
]))
else
:
raise
Exception
(
'shapeA and shapeB should be b
oar
dcasted, but got {} and {}'
.
'shapeA and shapeB should be b
roa
dcasted, but got {} and {}'
.
format
(
shapeA
,
shapeB
))
bsA
=
B
oar
dshape
+
list
(
shapeA
[
-
2
:])
bsB
=
B
oar
dshape
+
list
(
shapeB
[
-
2
:])
bsA
=
B
roa
dshape
+
list
(
shapeA
[
-
2
:])
bsB
=
B
roa
dshape
+
list
(
shapeB
[
-
2
:])
return
np
.
broadcast_to
(
matA
,
bsA
),
np
.
broadcast_to
(
matB
,
bsB
)
#cholesky_solve implement in batch
def
scipy_cholesky_solution_batch
(
bumat
,
bB
,
upper
=
True
):
bumat
,
bB
=
b
oar
dcast_shape
(
bumat
,
bB
)
bumat
,
bB
=
b
roa
dcast_shape
(
bumat
,
bB
)
ushape
=
bumat
.
shape
bshape
=
bB
.
shape
bumat
=
bumat
.
reshape
((
-
1
,
ushape
[
-
2
],
ushape
[
-
1
]))
...
...
@@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True):
return
np
.
array
(
bx
).
reshape
(
bshape
)
# 2D + 2D , , upper=False
# test condition: shape: 2D + 2D , upper=False
# based on OpTest class
class
TestCholeskySolveOp
(
OpTest
):
"""
case 1
"""
#test condition set
def
config
(
self
):
self
.
y_shape
=
[
15
,
15
]
self
.
x_shape
=
[
15
,
5
]
self
.
upper
=
False
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float64
#Here cholesky_solve Op only supports float64/float32 type, please check others if Op supports more types.
#get scipy result
def
set_output
(
self
):
umat
=
self
.
inputs
[
'Y'
]
self
.
output
=
scipy_cholesky_solution_batch
(
...
...
@@ -124,14 +131,16 @@ class TestCholeskySolveOp(OpTest):
self
.
set_output
()
self
.
outputs
=
{
'Out'
:
self
.
output
}
#check Op forward result
def
test_check_output
(
self
):
self
.
check_output
()
#check Op grad
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'Y'
],
'Out'
,
max_relative_error
=
0.01
)
# 3D(broadcast) + 3D, upper=True
#
test condition:
3D(broadcast) + 3D, upper=True
class
TestCholeskySolveOp3
(
TestCholeskySolveOp
):
"""
case 3
...
...
@@ -144,11 +153,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp):
self
.
dtype
=
np
.
float64
#API function test
class
TestCholeskySolveAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
2021
)
self
.
place
=
[
paddle
.
CPUPlace
()]
# self.place = [paddle.CUDAPlace(0)]
self
.
dtype
=
"float64"
self
.
upper
=
True
if
core
.
is_compiled_with_cuda
():
...
...
@@ -177,10 +186,12 @@ class TestCholeskySolveAPI(unittest.TestCase):
fetch_list
=
[
z
])
self
.
assertTrue
(
np
.
allclose
(
fetches
[
0
],
z_np
))
#test in static mode
def
test_static
(
self
):
for
place
in
self
.
place
:
self
.
check_static_result
(
place
=
place
)
#test in dynamic mode
def
test_dygraph
(
self
):
def
run
(
place
):
paddle
.
disable_static
(
place
)
...
...
@@ -199,7 +210,8 @@ class TestCholeskySolveAPI(unittest.TestCase):
for
idx
,
place
in
enumerate
(
self
.
place
):
run
(
place
)
def
test_boardcast
(
self
):
#test input with broadcast
def
test_broadcast
(
self
):
def
run
(
place
):
paddle
.
disable_static
()
x_np
=
np
.
random
.
random
([
1
,
30
,
2
]).
astype
(
self
.
dtype
)
...
...
@@ -218,6 +230,7 @@ class TestCholeskySolveAPI(unittest.TestCase):
run
(
place
)
#test condition out of bounds
class
TestCholeskySolveOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
paddle
.
enable_static
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录