未验证 提交 86282b41 编写于 作者: Y yangguohao 提交者: GitHub

Update utils.py

上级 809df36c
...@@ -1320,9 +1320,9 @@ def plot_n_qubit_state_in_bloch_sphere( ...@@ -1320,9 +1320,9 @@ def plot_n_qubit_state_in_bloch_sphere(
which_qubits = list(range(n_qubits)) which_qubits = list(range(n_qubits))
else: else:
assert type(which_qubits)==list,'the type of which_qubits should be None or list' assert type(which_qubits)==list,'the type of which_qubits should be None or list'
assert len(which_qubits)==state_len,'展示的量子数量需要小于n_qubits' assert 1<=len(which_qubits)<=n_qubits,'展示的量子数量需要小于n_qubits'
for i in range(len(which_qubits)): for i in range(len(which_qubits)):
assert 0<which_qubits[i]<n_qubits, '0<which_qubits[i]<n_qubits' assert 0<=which_qubits[i]<n_qubits, '0<which_qubits[i]<n_qubits'
# Assign a value to an empty variable # Assign a value to an empty variable
if filename is None: if filename is None:
...@@ -1337,14 +1337,13 @@ def plot_n_qubit_state_in_bloch_sphere( ...@@ -1337,14 +1337,13 @@ def plot_n_qubit_state_in_bloch_sphere(
state = state.numpy() state = state.numpy()
#state_vector to density matrix #state_vector to density matrix
if state.shape[0]>=2 and state.size==2*state.shape[0]: if state.shape[0]>=2 and state.size==state.shape[0]:
state_vector = state state_vector = state
state = np.outer(state_vector, np.conj(state_vector)) state = np.outer(state_vector, np.conj(state_vector))
#多量子态分解 #多量子态分解
if state.shape[0]>2: if state.shape[0]>2:
rho = paddle.to_tensor(state) rho = paddle.to_tensor(state)
print(rho)
tmp_s = [] tmp_s = []
for q in which_qubits: for q in which_qubits:
tmp_s.append(partial_trace_discontiguous(rho,[q])) tmp_s.append(partial_trace_discontiguous(rho,[q]))
...@@ -1387,8 +1386,8 @@ def plot_n_qubit_state_in_bloch_sphere( ...@@ -1387,8 +1386,8 @@ def plot_n_qubit_state_in_bloch_sphere(
# Helper function to plot vectors on a sphere. # Helper function to plot vectors on a sphere.
fig = plt.figure(figsize=(8, 8), dpi=100) fig = plt.figure(figsize=(8, 8), dpi=100)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1) fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
dim = np.ceil(sqrt(n_qubits)) dim = np.ceil(sqrt(len(which_qubits)))
for i in range(1,n_qubits+1): for i in range(1,len(which_qubits)+1):
ax = fig.add_subplot(dim,dim,i,projection='3d') ax = fig.add_subplot(dim,dim,i,projection='3d')
bloch_vector=np.array([bloch_vectors[i-1]]) bloch_vector=np.array([bloch_vectors[i-1]])
__plot_bloch_sphere( __plot_bloch_sphere(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册