diff --git a/paddle_quantum/utils.py b/paddle_quantum/utils.py index 9b72548807d04ae7dcb9cbe338b30b0f6a8476fb..81d476720907e527b018e6869aa487c41ea1cd7b 100644 --- a/paddle_quantum/utils.py +++ b/paddle_quantum/utils.py @@ -1310,14 +1310,7 @@ def plot_n_qubit_state_in_bloch_sphere( for i in range(state_len): assert type(state[i]) == paddle.Tensor or type(state[i]) == np.ndarray, \ 'the type of "state[i]" should be "paddle.Tensor" or "numpy.ndarray".' - - if show_qubits is None: - show_qubits = [None]*state_len - else: - assert len(show_qubits)==state_len,'show_qubits大小需要和state相同' - for i in range(state_len): - assert type(show_qubits[i])==list,'the type of show_qubits should be None or list' - + # Convert Tensor to numpy for i in range(state_len): if type(state[i]) == paddle.Tensor: @@ -1328,23 +1321,33 @@ def plot_n_qubit_state_in_bloch_sphere( if state[i].size == 2: state_vector = state[i] state[i] = np.outer(state_vector, np.conj(state_vector)) + if show_qubits is None: + show_qubits = [None]*state_len + else: + assert len(show_qubits)==state_len,'show_qubits大小需要和state相同' + for i in range(state_len): + assert type(show_qubits[i])==list,'the type of show_qubits should be None or list' + assert 02: s = [] if show_qubits[i] is None: - qubits_list = [*range(int(np.log2(state[i].shape[0])))] + qubits_list = list(range(int(np.log2(state[i].shape[0])))) else: qubits_list = show_qubits[i] rho = paddle.to_tensor(state[i]) for q in qubits_list: s.append(partial_trace_discontiguous(rho,[q])) - plot_state_in_bloch_sphere(s,**args) + #多量子态的子图的箭头向量改为蓝色 + plot_state_in_bloch_sphere(s,n_qubit=len(qubits_list),set_color='#0000FF',**args) else: plot_state_in_bloch_sphere(state[i],**args) def plot_state_in_bloch_sphere( state, + qubits_list=None, show_arrow=False, save_gif=False, filename=None, @@ -1356,6 +1359,7 @@ def plot_state_in_bloch_sphere( Args: state (list(numpy.ndarray or paddle.Tensor)): 输入的量子态列表,可以支持态矢量和密度矩阵 + n_qubit (int): 若为多量子态,则为需要展示的量子比特的数目 show_arrow (bool): 是否展示向量的箭头,默认为 ``False`` save_gif (bool): 是否存储 gif 动图,默认为 ``False`` filename (str): 存储的 gif 动图的名字 @@ -1433,13 +1437,23 @@ def plot_state_in_bloch_sphere( # Helper function to plot vectors on a sphere. fig = plt.figure(figsize=(8, 8), dpi=100) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - ax = fig.add_subplot(111, projection='3d') - - __plot_bloch_sphere( + + if qubits_list is None:#若为单量子态 + ax = fig.add_subplot(111, projection='3d') + __plot_bloch_sphere( ax, bloch_vectors, show_arrow, clear_plt=True, view_angle=view_angle, view_dist=view_dist, set_color=set_color - ) - + ) + else: #若为多量子态 + dim = np.ceil(sqrt(n_qubit)) + for i in range(1,n_qubit+1): + ax = fig.add_subplot(dim,dim,i,projection='3d') + bloch_vector=np.array([bloch_vectors[i-1]]) + __plot_bloch_sphere( + ax, bloch_vector, show_arrow, clear_plt=True, + view_angle=view_angle, view_dist=view_dist, set_color=set_color + ) + plt.show()