Unverified commit 64c268b2 authored by: Z zhiboniu, committed by: GitHub

Add more annotations to test_cholesky_solve_op.py, make it an example in the hackathon guide
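Since this file is meant to double as a guide example, here is a minimal NumPy/SciPy sketch (not part of the commit) of the math that the reference helpers in the diff verify: given the Cholesky factor u of a symmetric positive-definite matrix A, with A = u.T @ u when upper=True, cholesky_solve returns x such that A @ x = b. All variable names below are illustrative.

import numpy as np
import scipy.linalg

# Illustrative setup: build a symmetric positive-definite matrix A and its
# upper-triangular Cholesky factor u, so that A = u.T @ u.
rng = np.random.default_rng(2021)
n = 4
m = rng.random((n, n))
A = m @ m.T + n * np.eye(n)
u = np.linalg.cholesky(A).T          # upper-triangular factor
b = rng.random((n, 2))

# scipy.linalg.cho_solve takes (factor, lower); lower=False means u is upper-triangular.
x = scipy.linalg.cho_solve((u, False), b)
print(np.allclose(A @ x, b))         # True: x solves A @ x = b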

Parent 7fc0c619
......@@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core
paddle.enable_static()
#cholesky_solve reference implementation 1
def cholesky_solution(X, B, upper=True):
if upper:
A = np.triu(X)
......@@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True):
L, B, lower=True))
#cholesky_solve reference implementation 2
def scipy_cholesky_solution(X, B, upper=True):
if upper:
umat = np.triu(X)
......@@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True):
return scipy.linalg.cho_solve(K, B)
def boardcast_shape(matA, matB):
#broadcast function used by cholesky_solve
def broadcast_shape(matA, matB):
shapeA = matA.shape
shapeB = matB.shape
Boardshape = []
Broadshape = []
for idx in range(len(shapeA) - 2):
if shapeA[idx] == shapeB[idx]:
Boardshape.append(shapeA[idx])
Broadshape.append(shapeA[idx])
continue
elif shapeA[idx] == 1 or shapeB[idx] == 1:
Boardshape.append(max(shapeA[idx], shapeB[idx]))
Broadshape.append(max(shapeA[idx], shapeB[idx]))
else:
raise Exception(
'shapeA and shapeB should be boardcasted, but got {} and {}'.
'shapeA and shapeB should be broadcasted, but got {} and {}'.
format(shapeA, shapeB))
bsA = Boardshape + list(shapeA[-2:])
bsB = Boardshape + list(shapeB[-2:])
bsA = Broadshape + list(shapeA[-2:])
bsB = Broadshape + list(shapeB[-2:])
return np.broadcast_to(matA, bsA), np.broadcast_to(matB, bsB)
#cholesky_solve reference implementation for batched inputs
def scipy_cholesky_solution_batch(bumat, bB, upper=True):
bumat, bB = boardcast_shape(bumat, bB)
bumat, bB = broadcast_shape(bumat, bB)
ushape = bumat.shape
bshape = bB.shape
bumat = bumat.reshape((-1, ushape[-2], ushape[-1]))
......@@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True):
return np.array(bx).reshape(bshape)
# 2D + 2D , , upper=False
# test condition: shape 2D + 2D, upper=False
# based on OpTest class
class TestCholeskySolveOp(OpTest):
"""
case 1
"""
#set the test conditions (shapes, upper flag, dtype)
def config(self):
self.y_shape = [15, 15]
self.x_shape = [15, 5]
self.upper = False
self.dtype = np.float64
self.dtype = np.float64 #the cholesky_solve Op only supports float64/float32 here; add other dtypes if the Op supports them
#get scipy result
def set_output(self):
umat = self.inputs['Y']
self.output = scipy_cholesky_solution_batch(
......@@ -124,14 +131,16 @@ class TestCholeskySolveOp(OpTest):
self.set_output()
self.outputs = {'Out': self.output}
#check Op forward result
def test_check_output(self):
self.check_output()
#check Op grad
def test_check_grad_normal(self):
self.check_grad(['Y'], 'Out', max_relative_error=0.01)
# 3D(broadcast) + 3D, upper=True
# test condition: 3D(broadcast) + 3D, upper=True
class TestCholeskySolveOp3(TestCholeskySolveOp):
"""
case 3
......@@ -144,11 +153,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp):
self.dtype = np.float64
#API function test
class TestCholeskySolveAPI(unittest.TestCase):
def setUp(self):
np.random.seed(2021)
self.place = [paddle.CPUPlace()]
# self.place = [paddle.CUDAPlace(0)]
self.dtype = "float64"
self.upper = True
if core.is_compiled_with_cuda():
......@@ -177,10 +186,12 @@ class TestCholeskySolveAPI(unittest.TestCase):
fetch_list=[z])
self.assertTrue(np.allclose(fetches[0], z_np))
#test in static mode
def test_static(self):
for place in self.place:
self.check_static_result(place=place)
#test in dynamic mode
def test_dygraph(self):
def run(place):
paddle.disable_static(place)
......@@ -199,7 +210,8 @@ class TestCholeskySolveAPI(unittest.TestCase):
for idx, place in enumerate(self.place):
run(place)
def test_boardcast(self):
#test input with broadcast
def test_broadcast(self):
def run(place):
paddle.disable_static()
x_np = np.random.random([1, 30, 2]).astype(self.dtype)
......@@ -218,6 +230,7 @@ class TestCholeskySolveAPI(unittest.TestCase):
run(place)
#test error handling for invalid inputs
class TestCholeskySolveOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
......
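For completeness (the bodies of the API tests above are collapsed in this view), here is a hedged sketch of the dygraph-mode call those tests exercise. It is not part of the commit and assumes the paddle.linalg.cholesky_solve(x, y, upper) signature, with y the Cholesky factor and x the right-hand sides; shapes and values are illustrative.

import numpy as np
import paddle

# Illustrative inputs: a nonsingular upper-triangular factor u and right-hand sides b.
u_np = np.triu(np.random.random((10, 10))).astype("float64") + 10 * np.eye(10)
b_np = np.random.random((10, 3)).astype("float64")

paddle.disable_static()
u = paddle.to_tensor(u_np)
b = paddle.to_tensor(b_np)
# Assumed call order: (right-hand sides, factor); with upper=True this solves (u.T @ u) @ out = b.
out = paddle.linalg.cholesky_solve(b, u, upper=True)
print(np.allclose((u_np.T @ u_np) @ out.numpy(), b_np))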