Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
64c268b2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
64c268b2
编写于
3月 22, 2022
作者:
Z
zhiboniu
提交者:
GitHub
3月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add more annotations to test_cholesky_solve_op.py, make it an example in hackson guide
上级
7fc0c619
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
26 addition
and
13 deletion
+26
-13
python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py
...on/paddle/fluid/tests/unittests/test_cholesky_solve_op.py
+26
-13
未找到文件。
python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py
浏览文件 @
64c268b2
...
...
@@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core
paddle
.
enable_static
()
#cholesky_solve implement 1
def
cholesky_solution
(
X
,
B
,
upper
=
True
):
if
upper
:
A
=
np
.
triu
(
X
)
...
...
@@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True):
L
,
B
,
lower
=
True
))
#cholesky_solve implement 2
def
scipy_cholesky_solution
(
X
,
B
,
upper
=
True
):
if
upper
:
umat
=
np
.
triu
(
X
)
...
...
@@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True):
return
scipy
.
linalg
.
cho_solve
(
K
,
B
)
def
boardcast_shape
(
matA
,
matB
):
#broadcast function used by cholesky_solve
def
broadcast_shape
(
matA
,
matB
):
shapeA
=
matA
.
shape
shapeB
=
matB
.
shape
B
oar
dshape
=
[]
B
roa
dshape
=
[]
for
idx
in
range
(
len
(
shapeA
)
-
2
):
if
shapeA
[
idx
]
==
shapeB
[
idx
]:
B
oar
dshape
.
append
(
shapeA
[
idx
])
B
roa
dshape
.
append
(
shapeA
[
idx
])
continue
elif
shapeA
[
idx
]
==
1
or
shapeB
[
idx
]
==
1
:
B
oar
dshape
.
append
(
max
(
shapeA
[
idx
],
shapeB
[
idx
]))
B
roa
dshape
.
append
(
max
(
shapeA
[
idx
],
shapeB
[
idx
]))
else
:
raise
Exception
(
'shapeA and shapeB should be b
oar
dcasted, but got {} and {}'
.
'shapeA and shapeB should be b
roa
dcasted, but got {} and {}'
.
format
(
shapeA
,
shapeB
))
bsA
=
B
oar
dshape
+
list
(
shapeA
[
-
2
:])
bsB
=
B
oar
dshape
+
list
(
shapeB
[
-
2
:])
bsA
=
B
roa
dshape
+
list
(
shapeA
[
-
2
:])
bsB
=
B
roa
dshape
+
list
(
shapeB
[
-
2
:])
return
np
.
broadcast_to
(
matA
,
bsA
),
np
.
broadcast_to
(
matB
,
bsB
)
#cholesky_solve implement in batch
def
scipy_cholesky_solution_batch
(
bumat
,
bB
,
upper
=
True
):
bumat
,
bB
=
b
oar
dcast_shape
(
bumat
,
bB
)
bumat
,
bB
=
b
roa
dcast_shape
(
bumat
,
bB
)
ushape
=
bumat
.
shape
bshape
=
bB
.
shape
bumat
=
bumat
.
reshape
((
-
1
,
ushape
[
-
2
],
ushape
[
-
1
]))
...
...
@@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True):
return
np
.
array
(
bx
).
reshape
(
bshape
)
# 2D + 2D , , upper=False
# test condition: shape: 2D + 2D , upper=False
# based on OpTest class
class
TestCholeskySolveOp
(
OpTest
):
"""
case 1
"""
#test condition set
def
config
(
self
):
self
.
y_shape
=
[
15
,
15
]
self
.
x_shape
=
[
15
,
5
]
self
.
upper
=
False
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float64
#Here cholesky_solve Op only supports float64/float32 type, please check others if Op supports more types.
#get scipy result
def
set_output
(
self
):
umat
=
self
.
inputs
[
'Y'
]
self
.
output
=
scipy_cholesky_solution_batch
(
...
...
@@ -124,14 +131,16 @@ class TestCholeskySolveOp(OpTest):
self
.
set_output
()
self
.
outputs
=
{
'Out'
:
self
.
output
}
#check Op forward result
def
test_check_output
(
self
):
self
.
check_output
()
#check Op grad
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'Y'
],
'Out'
,
max_relative_error
=
0.01
)
# 3D(broadcast) + 3D, upper=True
#
test condition:
3D(broadcast) + 3D, upper=True
class
TestCholeskySolveOp3
(
TestCholeskySolveOp
):
"""
case 3
...
...
@@ -144,11 +153,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp):
self
.
dtype
=
np
.
float64
#API function test
class
TestCholeskySolveAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
2021
)
self
.
place
=
[
paddle
.
CPUPlace
()]
# self.place = [paddle.CUDAPlace(0)]
self
.
dtype
=
"float64"
self
.
upper
=
True
if
core
.
is_compiled_with_cuda
():
...
...
@@ -177,10 +186,12 @@ class TestCholeskySolveAPI(unittest.TestCase):
fetch_list
=
[
z
])
self
.
assertTrue
(
np
.
allclose
(
fetches
[
0
],
z_np
))
#test in static mode
def
test_static
(
self
):
for
place
in
self
.
place
:
self
.
check_static_result
(
place
=
place
)
#test in dynamic mode
def
test_dygraph
(
self
):
def
run
(
place
):
paddle
.
disable_static
(
place
)
...
...
@@ -199,7 +210,8 @@ class TestCholeskySolveAPI(unittest.TestCase):
for
idx
,
place
in
enumerate
(
self
.
place
):
run
(
place
)
def
test_boardcast
(
self
):
#test input with broadcast
def
test_broadcast
(
self
):
def
run
(
place
):
paddle
.
disable_static
()
x_np
=
np
.
random
.
random
([
1
,
30
,
2
]).
astype
(
self
.
dtype
)
...
...
@@ -218,6 +230,7 @@ class TestCholeskySolveAPI(unittest.TestCase):
run
(
place
)
#test condition out of bounds
class
TestCholeskySolveOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
paddle
.
enable_static
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录