diff --git a/paddle_quantum/utils.py b/paddle_quantum/utils.py index 593ec71a84bb97892e0d97822847e5dd507143e4..f8ebb7c5e8d9671be1cf743c2c76dc6426684372 100644 --- a/paddle_quantum/utils.py +++ b/paddle_quantum/utils.py @@ -1285,6 +1285,7 @@ def __plot_bloch_sphere( def plot_n_qubit_state_in_bloch_sphere( state, which_qubits=None, + show_arrow=False, save_gif=False, filename=None, view_angle=None, @@ -1313,10 +1314,10 @@ def plot_n_qubit_state_in_bloch_sphere( assert type(set_color) == str, \ 'the type of "set_color" should be "str".' - n_qubits = np.log2(state.shape[0]) + n_qubits = int(np.log2(state.shape[0])) if which_qubits is None: - which_qubits = list(range(int(n_qubits))) + which_qubits = list(range(n_qubits)) else: assert type(show_qubits)==list,'the type of which_qubits should be None or list' assert len(show_qubits)==state_len,'展示的量子数量需要小于n_qubits' @@ -1380,8 +1381,8 @@ def plot_n_qubit_state_in_bloch_sphere( # Helper function to plot vectors on a sphere. fig = plt.figure(figsize=(8, 8), dpi=100) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - dim = np.ceil(sqrt(n_qubit)) - for i in range(1,n_qubit+1): + dim = np.ceil(sqrt(n_qubits)) + for i in range(1,n_qubits+1): ax = fig.add_subplot(dim,dim,i,projection='3d') bloch_vector=np.array([bloch_vectors[i-1]]) __plot_bloch_sphere(