diff --git a/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py b/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py index c31594b75e985c3237459c7c6ea6058aaf31fdcc..bb45a52566211311fc2585e78ef6584bf9a06a20 100644 --- a/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py +++ b/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py @@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core paddle.enable_static() +#cholesky_solve implement 1 def cholesky_solution(X, B, upper=True): if upper: A = np.triu(X) @@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True): L, B, lower=True)) +#cholesky_solve implement 2 def scipy_cholesky_solution(X, B, upper=True): if upper: umat = np.triu(X) @@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True): return scipy.linalg.cho_solve(K, B) -def boardcast_shape(matA, matB): +#broadcast function used by cholesky_solve +def broadcast_shape(matA, matB): shapeA = matA.shape shapeB = matB.shape - Boardshape = [] + Broadshape = [] for idx in range(len(shapeA) - 2): if shapeA[idx] == shapeB[idx]: - Boardshape.append(shapeA[idx]) + Broadshape.append(shapeA[idx]) continue elif shapeA[idx] == 1 or shapeB[idx] == 1: - Boardshape.append(max(shapeA[idx], shapeB[idx])) + Broadshape.append(max(shapeA[idx], shapeB[idx])) else: raise Exception( - 'shapeA and shapeB should be boardcasted, but got {} and {}'. + 'shapeA and shapeB should be broadcasted, but got {} and {}'. format(shapeA, shapeB)) - bsA = Boardshape + list(shapeA[-2:]) - bsB = Boardshape + list(shapeB[-2:]) + bsA = Broadshape + list(shapeA[-2:]) + bsB = Broadshape + list(shapeB[-2:]) return np.broadcast_to(matA, bsA), np.broadcast_to(matB, bsB) +#cholesky_solve implement in batch def scipy_cholesky_solution_batch(bumat, bB, upper=True): - bumat, bB = boardcast_shape(bumat, bB) + bumat, bB = broadcast_shape(bumat, bB) ushape = bumat.shape bshape = bB.shape bumat = bumat.reshape((-1, ushape[-2], ushape[-1])) @@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True): return np.array(bx).reshape(bshape) -# 2D + 2D , , upper=False +# test condition: shape: 2D + 2D , upper=False +# based on OpTest class class TestCholeskySolveOp(OpTest): """ case 1 """ + #test condition set def config(self): self.y_shape = [15, 15] self.x_shape = [15, 5] self.upper = False - self.dtype = np.float64 + self.dtype = np.float64 #Here cholesky_solve Op only supports float64/float32 type, please check others if Op supports more types. + #get scipy result def set_output(self): umat = self.inputs['Y'] self.output = scipy_cholesky_solution_batch( @@ -124,14 +131,16 @@ class TestCholeskySolveOp(OpTest): self.set_output() self.outputs = {'Out': self.output} + #check Op forward result def test_check_output(self): self.check_output() + #check Op grad def test_check_grad_normal(self): self.check_grad(['Y'], 'Out', max_relative_error=0.01) -# 3D(broadcast) + 3D, upper=True +# test condition: 3D(broadcast) + 3D, upper=True class TestCholeskySolveOp3(TestCholeskySolveOp): """ case 3 @@ -144,11 +153,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp): self.dtype = np.float64 +#API function test class TestCholeskySolveAPI(unittest.TestCase): def setUp(self): np.random.seed(2021) self.place = [paddle.CPUPlace()] - # self.place = [paddle.CUDAPlace(0)] self.dtype = "float64" self.upper = True if core.is_compiled_with_cuda(): @@ -177,10 +186,12 @@ class TestCholeskySolveAPI(unittest.TestCase): fetch_list=[z]) self.assertTrue(np.allclose(fetches[0], z_np)) + #test in static mode def test_static(self): for place in self.place: self.check_static_result(place=place) + #test in dynamic mode def test_dygraph(self): def run(place): paddle.disable_static(place) @@ -199,7 +210,8 @@ class TestCholeskySolveAPI(unittest.TestCase): for idx, place in enumerate(self.place): run(place) - def test_boardcast(self): + #test input with broadcast + def test_broadcast(self): def run(place): paddle.disable_static() x_np = np.random.random([1, 30, 2]).astype(self.dtype) @@ -218,6 +230,7 @@ class TestCholeskySolveAPI(unittest.TestCase): run(place) +#test condition out of bounds class TestCholeskySolveOpError(unittest.TestCase): def test_errors(self): paddle.enable_static()