未验证 提交 64c268b2 编写于 作者: Z zhiboniu 提交者: GitHub

Add more annotations to test_cholesky_solve_op.py, make it an example in hackson guide

上级 7fc0c619
...@@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core ...@@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core
paddle.enable_static() paddle.enable_static()
#cholesky_solve implement 1
def cholesky_solution(X, B, upper=True): def cholesky_solution(X, B, upper=True):
if upper: if upper:
A = np.triu(X) A = np.triu(X)
...@@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True): ...@@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True):
L, B, lower=True)) L, B, lower=True))
#cholesky_solve implement 2
def scipy_cholesky_solution(X, B, upper=True): def scipy_cholesky_solution(X, B, upper=True):
if upper: if upper:
umat = np.triu(X) umat = np.triu(X)
...@@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True): ...@@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True):
return scipy.linalg.cho_solve(K, B) return scipy.linalg.cho_solve(K, B)
def boardcast_shape(matA, matB): #broadcast function used by cholesky_solve
def broadcast_shape(matA, matB):
shapeA = matA.shape shapeA = matA.shape
shapeB = matB.shape shapeB = matB.shape
Boardshape = [] Broadshape = []
for idx in range(len(shapeA) - 2): for idx in range(len(shapeA) - 2):
if shapeA[idx] == shapeB[idx]: if shapeA[idx] == shapeB[idx]:
Boardshape.append(shapeA[idx]) Broadshape.append(shapeA[idx])
continue continue
elif shapeA[idx] == 1 or shapeB[idx] == 1: elif shapeA[idx] == 1 or shapeB[idx] == 1:
Boardshape.append(max(shapeA[idx], shapeB[idx])) Broadshape.append(max(shapeA[idx], shapeB[idx]))
else: else:
raise Exception( raise Exception(
'shapeA and shapeB should be boardcasted, but got {} and {}'. 'shapeA and shapeB should be broadcasted, but got {} and {}'.
format(shapeA, shapeB)) format(shapeA, shapeB))
bsA = Boardshape + list(shapeA[-2:]) bsA = Broadshape + list(shapeA[-2:])
bsB = Boardshape + list(shapeB[-2:]) bsB = Broadshape + list(shapeB[-2:])
return np.broadcast_to(matA, bsA), np.broadcast_to(matB, bsB) return np.broadcast_to(matA, bsA), np.broadcast_to(matB, bsB)
#cholesky_solve implement in batch
def scipy_cholesky_solution_batch(bumat, bB, upper=True): def scipy_cholesky_solution_batch(bumat, bB, upper=True):
bumat, bB = boardcast_shape(bumat, bB) bumat, bB = broadcast_shape(bumat, bB)
ushape = bumat.shape ushape = bumat.shape
bshape = bB.shape bshape = bB.shape
bumat = bumat.reshape((-1, ushape[-2], ushape[-1])) bumat = bumat.reshape((-1, ushape[-2], ushape[-1]))
...@@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True): ...@@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True):
return np.array(bx).reshape(bshape) return np.array(bx).reshape(bshape)
# 2D + 2D , , upper=False # test condition: shape: 2D + 2D , upper=False
# based on OpTest class
class TestCholeskySolveOp(OpTest): class TestCholeskySolveOp(OpTest):
""" """
case 1 case 1
""" """
#test condition set
def config(self): def config(self):
self.y_shape = [15, 15] self.y_shape = [15, 15]
self.x_shape = [15, 5] self.x_shape = [15, 5]
self.upper = False self.upper = False
self.dtype = np.float64 self.dtype = np.float64 #Here cholesky_solve Op only supports float64/float32 type, please check others if Op supports more types.
#get scipy result
def set_output(self): def set_output(self):
umat = self.inputs['Y'] umat = self.inputs['Y']
self.output = scipy_cholesky_solution_batch( self.output = scipy_cholesky_solution_batch(
...@@ -124,14 +131,16 @@ class TestCholeskySolveOp(OpTest): ...@@ -124,14 +131,16 @@ class TestCholeskySolveOp(OpTest):
self.set_output() self.set_output()
self.outputs = {'Out': self.output} self.outputs = {'Out': self.output}
#check Op forward result
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
#check Op grad
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['Y'], 'Out', max_relative_error=0.01) self.check_grad(['Y'], 'Out', max_relative_error=0.01)
# 3D(broadcast) + 3D, upper=True # test condition: 3D(broadcast) + 3D, upper=True
class TestCholeskySolveOp3(TestCholeskySolveOp): class TestCholeskySolveOp3(TestCholeskySolveOp):
""" """
case 3 case 3
...@@ -144,11 +153,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp): ...@@ -144,11 +153,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp):
self.dtype = np.float64 self.dtype = np.float64
#API function test
class TestCholeskySolveAPI(unittest.TestCase): class TestCholeskySolveAPI(unittest.TestCase):
def setUp(self): def setUp(self):
np.random.seed(2021) np.random.seed(2021)
self.place = [paddle.CPUPlace()] self.place = [paddle.CPUPlace()]
# self.place = [paddle.CUDAPlace(0)]
self.dtype = "float64" self.dtype = "float64"
self.upper = True self.upper = True
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
...@@ -177,10 +186,12 @@ class TestCholeskySolveAPI(unittest.TestCase): ...@@ -177,10 +186,12 @@ class TestCholeskySolveAPI(unittest.TestCase):
fetch_list=[z]) fetch_list=[z])
self.assertTrue(np.allclose(fetches[0], z_np)) self.assertTrue(np.allclose(fetches[0], z_np))
#test in static mode
def test_static(self): def test_static(self):
for place in self.place: for place in self.place:
self.check_static_result(place=place) self.check_static_result(place=place)
#test in dynamic mode
def test_dygraph(self): def test_dygraph(self):
def run(place): def run(place):
paddle.disable_static(place) paddle.disable_static(place)
...@@ -199,7 +210,8 @@ class TestCholeskySolveAPI(unittest.TestCase): ...@@ -199,7 +210,8 @@ class TestCholeskySolveAPI(unittest.TestCase):
for idx, place in enumerate(self.place): for idx, place in enumerate(self.place):
run(place) run(place)
def test_boardcast(self): #test input with broadcast
def test_broadcast(self):
def run(place): def run(place):
paddle.disable_static() paddle.disable_static()
x_np = np.random.random([1, 30, 2]).astype(self.dtype) x_np = np.random.random([1, 30, 2]).astype(self.dtype)
...@@ -218,6 +230,7 @@ class TestCholeskySolveAPI(unittest.TestCase): ...@@ -218,6 +230,7 @@ class TestCholeskySolveAPI(unittest.TestCase):
run(place) run(place)
#test condition out of bounds
class TestCholeskySolveOpError(unittest.TestCase): class TestCholeskySolveOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册