# test.py (1.6 KB)
# [blame] yangguohao committed source line 1
from paddle_quantum.circuit import UAnsatz
# [blame] gsq7474741 committed source lines 2-4
import matplotlib.pyplot as plt
from paddle_quantum.utils import plot_density_graph
import numpy as np
# [blame] yangguohao committed source line 5
import paddle
# [blame] gsq7474741 committed source line 6
import unittest
# [blame] yangguohao committed source lines 7-8


# [blame] yangguohao committed source lines 9-10
#density_matrix
def test_density_matrix():
    """Run two RY circuits in density-matrix mode and print their states.

    The first circuit is created on 1 qubit, run, and then passed through
    ``expand(3)``; the second is created directly on 3 qubits. Both apply
    ``ry`` with angle 1 (float64) on qubit 0. Nothing is asserted: the two
    states are only printed.
    """
    cir = UAnsatz(1)
    cir.ry(paddle.to_tensor(1, dtype='float64'), 0)
    # The return value is not needed here; the state is read back through
    # get_state() after the expansion.
    cir.run_density_matrix()
    cir.expand(3)
    print(cir.get_state())

    # NOTE(review): presumably this 3-qubit circuit is the reference the
    # expanded state above should match -- confirm, then turn into an assert.
    cir2 = UAnsatz(3)
    cir2.ry(paddle.to_tensor(1, dtype='float64'), 0)
    cir2.run_density_matrix()
    print(cir2.get_state())
# [blame] yangguohao committed source line 21

# [blame] yangguohao committed source lines 22-28
#state_vector
def test_state_vector():
    """Run two RY circuits in state-vector mode and print their states.

    The first circuit is created on 1 qubit, run, and then passed through
    ``expand(3)``; the second is created directly on 3 qubits. Both apply
    ``ry`` with angle 1 (float64) on qubit 0. Nothing is asserted: the two
    states are only printed.
    """
    cir = UAnsatz(1)
    cir.ry(paddle.to_tensor(1, dtype='float64'), 0)
    # The return value is not needed here; the state is read back through
    # get_state() after the expansion.
    cir.run_state_vector()
    cir.expand(3)
    print(cir.get_state())

    # NOTE(review): presumably this 3-qubit circuit is the reference the
    # expanded state above should match -- confirm, then turn into an assert.
    cir2 = UAnsatz(3)
    cir2.ry(paddle.to_tensor(1, dtype='float64'), 0)
    cir2.run_state_vector()
    print(cir2.get_state())
# [blame] yangguohao committed source line 34

# [blame] gsq7474741 committed source lines 35-38

class TestPlotDensityGraph(unittest.TestCase):
    """Input-validation and smoke tests for ``plot_density_graph``."""

    def setUp(self):
        """Build a random complex 8x8 matrix as an ndarray and a paddle tensor."""
        self.func = plot_density_graph
        # np.random.rand draws from [0, 1); shifting by 0.5 + 0.5j centres
        # both the real and the imaginary parts around zero.
        self.x_np = (np.random.rand(8, 8) + np.random.rand(8, 8) * 1j) - 0.5 - 0.5j
        self.x_tensor = paddle.to_tensor(self.x_np)

    def test_input_type(self):
        """An int or a plain list must be rejected with ``TypeError``."""
        with self.assertRaises(TypeError):
            self.func(1)
        with self.assertRaises(TypeError):
            self.func([1, 2, 3])

    def test_input_shape(self):
        """A non-square (2, 3) array must be rejected with ``ValueError``."""
        x = np.zeros((2, 3))
        with self.assertRaises(ValueError):
            self.func(x)

    def test_ndarray_input_inputs(self):
        """Plotting a NumPy array returns an object whose ``show()`` can be called."""
        # NOTE(review): presumably a matplotlib Figure; show() may open a
        # window under an interactive backend -- confirm for CI.
        res = self.func(self.x_np)
        res.show()

    def test_tensor_input(self):
        """Plotting a paddle tensor returns an object whose ``show()`` can be called."""
        res = self.func(self.x_tensor)
        res.show()


if __name__ == '__main__':
    # Run the two print-only circuit checks first, in their original order,
    # then hand control to unittest for the TestCase classes.
    for circuit_check in (test_density_matrix, test_state_vector):
        circuit_check()
    unittest.main()