未验证 提交 72375b21 编写于 作者: Q QuLeaf 提交者: GitHub

update the contents to adapt the paddle-quantum v2.3.0 (#5712)

上级 ae6828f4
......@@ -2,17 +2,17 @@
"cells": [
{
"cell_type": "markdown",
"id": "chronic-tunisia",
"id": "0b1a3a86",
"metadata": {},
"source": [
"## 1. VSQL 模型简介\n",
"\n",
"变分影子量子学习(variational shadow quantum learning, VSQL)是一个在监督学习框架下的量子–经典混合算法。它使用了参数化量子电路(parameterized quantum circuit, PQC)和经典影子(classical shadow),和通常使用的变分量子算法(variational quantum alogorithm, VQA)不同的是,VSQL 只从子空间获取局部特征,而不是从整个希尔伯特空间获取特征。"
"变分影子量子学习(Variational Shadow Quantum Learning, VSQL)是一个在监督学习框架下的量子–经典混合算法。它使用了参数化量子电路(Parameterized Quantum Circuit, PQC)和经典影子(classical shadow),和通常使用的变分量子算法(Variational Quantum Algorithm, VQA)不同的是,VSQL 只从子空间获取局部特征,而不是从整个希尔伯特空间获取特征。"
]
},
{
"cell_type": "markdown",
"id": "8429d648",
"id": "6aa97879",
"metadata": {},
"source": [
"## 2. 模型原理简介\n",
......@@ -95,229 +95,75 @@
"source": [
"## 4. 模型如何使用\n",
"\n",
"按照如下代码来配置环境:"
"按照如下代码来安装量桨:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "eb7a2be4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://mirrors.bfsu.edu.cn/pypi/web/simple\n",
"Requirement already satisfied: paddle-quantum in /Users/wangzihe/temp/baidu/QPlatform/PaddleQu (2.2.1)\n",
"Requirement already satisfied: paddlepaddle<=2.3.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (2.3.0)\n",
"Requirement already satisfied: scipy in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (1.7.1)\n",
"Collecting protobuf<=3.20.1\n",
" Using cached https://mirrors.bfsu.edu.cn/pypi/web/packages/92/0e/b8a60441178c8725fb3afa648e80c312a77feab31e7831d69c672b3c18cc/protobuf-3.20.1-cp37-cp37m-macosx_10_9_x86_64.whl (961 kB)\n",
"Requirement already satisfied: networkx>=2.5 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (2.6.3)\n",
"Requirement already satisfied: qcompute in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (3.0.0)\n",
"Requirement already satisfied: matplotlib>=3.3.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (3.5.2)\n",
"Requirement already satisfied: tqdm in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (4.64.0)\n",
"Requirement already satisfied: openfermion in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (1.5.1)\n",
"Requirement already satisfied: opencv-python in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (4.6.0.66)\n",
"Requirement already satisfied: scikit-learn in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (1.0.2)\n",
"Requirement already satisfied: fastdtw in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (0.3.4)\n",
"Requirement already satisfied: cvxpy in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (1.2.2)\n",
"Requirement already satisfied: rich in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (12.0.1)\n",
"Requirement already satisfied: pyscf in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddle-quantum) (2.1.1)\n",
"Requirement already satisfied: pillow>=6.2.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (9.3.0)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (3.0.9)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (2.8.2)\n",
"Requirement already satisfied: packaging>=20.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (21.3)\n",
"Requirement already satisfied: cycler>=0.10 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (0.11.0)\n",
"Requirement already satisfied: numpy>=1.17 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (1.21.3)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (4.38.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from matplotlib>=3.3.0->paddle-quantum) (1.4.4)\n",
"Requirement already satisfied: paddle-bfloat==0.1.2 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddlepaddle<=2.3.0->paddle-quantum) (0.1.2)\n",
"Requirement already satisfied: astor in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddlepaddle<=2.3.0->paddle-quantum) (0.8.1)\n",
"Requirement already satisfied: requests>=2.20.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddlepaddle<=2.3.0->paddle-quantum) (2.28.0)\n",
"Requirement already satisfied: six in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddlepaddle<=2.3.0->paddle-quantum) (1.16.0)\n",
"Requirement already satisfied: decorator in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddlepaddle<=2.3.0->paddle-quantum) (5.1.1)\n",
"Requirement already satisfied: opt-einsum==3.3.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from paddlepaddle<=2.3.0->paddle-quantum) (3.3.0)\n",
"Requirement already satisfied: scs>=1.1.6 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cvxpy->paddle-quantum) (3.2.2)\n",
"Requirement already satisfied: osqp>=0.4.1 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cvxpy->paddle-quantum) (0.6.2.post5)\n",
"Requirement already satisfied: ecos>=2 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cvxpy->paddle-quantum) (2.0.10)\n",
"Requirement already satisfied: cirq-google>=0.15.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from openfermion->paddle-quantum) (1.0.0)\n",
"Requirement already satisfied: sympy in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from openfermion->paddle-quantum) (1.10.1)\n",
"Requirement already satisfied: deprecation in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from openfermion->paddle-quantum) (2.1.0)\n",
"Requirement already satisfied: cirq-core>=0.15.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from openfermion->paddle-quantum) (1.0.0)\n",
"Requirement already satisfied: pubchempy in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from openfermion->paddle-quantum) (1.0.4)\n",
"Requirement already satisfied: h5py>=2.8 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from openfermion->paddle-quantum) (3.7.0)\n",
"Collecting scipy\n",
" Using cached https://mirrors.bfsu.edu.cn/pypi/web/packages/4c/4a/440cc9703938bbc86636ff6b9e17810f3d0f06e9b41891c5433dc4cd9091/scipy-1.1.0-cp37-cp37m-macosx_10_6_intel.macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl (16.7 MB)\n",
"Collecting qcompute\n",
" Using cached https://mirrors.bfsu.edu.cn/pypi/web/packages/f6/57/823a3f3e3fd6e327453c4d028751fee7292784e1f2447ae6b7c0f3cc6565/qcompute-2.0.6-py3-none-any.whl (172 kB)\n",
" Using cached https://mirrors.bfsu.edu.cn/pypi/web/packages/8e/03/0a64ec2b7e6395fa53688ad3e489163218c45f36f0c96b448e0279391538/qcompute-2.0.4-py3-none-any.whl (96 kB)\n",
"Requirement already satisfied: bidict in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from qcompute->paddle-quantum) (0.22.0)\n",
"Requirement already satisfied: pyprimes in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from qcompute->paddle-quantum) (0.1)\n",
"Requirement already satisfied: bce-python-sdk in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from qcompute->paddle-quantum) (0.8.73)\n",
"Requirement already satisfied: commonmark<0.10.0,>=0.9.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from rich->paddle-quantum) (0.9.1)\n",
"Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from rich->paddle-quantum) (2.13.0)\n",
"Requirement already satisfied: typing-extensions<5.0,>=3.7.4 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from rich->paddle-quantum) (4.3.0)\n",
"Requirement already satisfied: joblib>=0.11 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from scikit-learn->paddle-quantum) (1.2.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from scikit-learn->paddle-quantum) (3.1.0)\n",
"Requirement already satisfied: backports.cached-property~=1.0.1 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cirq-core>=0.15.0->openfermion->paddle-quantum) (1.0.2)\n",
"Requirement already satisfied: duet~=0.2.7 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cirq-core>=0.15.0->openfermion->paddle-quantum) (0.2.7)\n",
"Requirement already satisfied: pandas in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cirq-core>=0.15.0->openfermion->paddle-quantum) (1.3.5)\n",
"Requirement already satisfied: sortedcontainers~=2.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cirq-core>=0.15.0->openfermion->paddle-quantum) (2.4.0)\n",
"Requirement already satisfied: google-api-core[grpc]<2.0.0dev,>=1.14.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cirq-google>=0.15.0->openfermion->paddle-quantum) (1.33.2)\n",
"Requirement already satisfied: proto-plus>=1.20.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from cirq-google>=0.15.0->openfermion->paddle-quantum) (1.22.1)\n",
"Requirement already satisfied: qdldl in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from osqp>=0.4.1->cvxpy->paddle-quantum) (0.1.5.post2)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle<=2.3.0->paddle-quantum) (2.0.12)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle<=2.3.0->paddle-quantum) (2022.9.24)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle<=2.3.0->paddle-quantum) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle<=2.3.0->paddle-quantum) (1.26.12)\n",
"Requirement already satisfied: future>=0.6.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from bce-python-sdk->qcompute->paddle-quantum) (0.18.2)\n",
"Requirement already satisfied: pycryptodome>=3.8.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from bce-python-sdk->qcompute->paddle-quantum) (3.15.0)\n",
"Requirement already satisfied: mpmath>=0.19 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from sympy->openfermion->paddle-quantum) (1.2.1)\n",
"Collecting typing-extensions<5.0,>=3.7.4\n",
" Using cached https://mirrors.bfsu.edu.cn/pypi/web/packages/2e/35/6c4fff5ab443b57116cb1aad46421fb719bed2825664e8fe77d66d99bcbc/typing_extensions-3.10.0.0-py3-none-any.whl (26 kB)\n",
"Collecting protobuf<=3.20.1\n",
" Using cached https://mirrors.bfsu.edu.cn/pypi/web/packages/ea/fe/82cf68917308b208731487f986db209e56903c30e324499b6bf0cc6a6203/protobuf-3.19.6-cp37-cp37m-macosx_10_9_x86_64.whl (979 kB)\n",
"Requirement already satisfied: google-auth<3.0dev,>=1.25.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (2.14.1)\n",
"Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.56.2 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (1.57.0)\n",
"Requirement already satisfied: grpcio<2.0dev,>=1.33.2 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (1.50.0)\n",
"Requirement already satisfied: grpcio-status<2.0dev,>=1.33.2 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (1.48.2)\n",
"Requirement already satisfied: pytz>=2017.3 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from pandas->cirq-core>=0.15.0->openfermion->paddle-quantum) (2022.6)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from google-auth<3.0dev,>=1.25.0->google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (0.2.8)\n",
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from google-auth<3.0dev,>=1.25.0->google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (5.2.0)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from google-auth<3.0dev,>=1.25.0->google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (4.9)\n",
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages (from pyasn1-modules>=0.2.1->google-auth<3.0dev,>=1.25.0->google-api-core[grpc]<2.0.0dev,>=1.14.0->cirq-google>=0.15.0->openfermion->paddle-quantum) (0.4.8)\n",
"Installing collected packages: typing-extensions, scipy, protobuf, qcompute\n",
" Attempting uninstall: typing-extensions\n",
" Found existing installation: typing_extensions 4.3.0\n",
" Uninstalling typing_extensions-4.3.0:\n",
" Successfully uninstalled typing_extensions-4.3.0\n",
" Attempting uninstall: scipy\n",
" Found existing installation: scipy 1.7.1\n",
" Uninstalling scipy-1.7.1:\n",
" Successfully uninstalled scipy-1.7.1\n",
" Attempting uninstall: protobuf\n",
" Found existing installation: protobuf 4.21.1\n",
" Uninstalling protobuf-4.21.1:\n",
" Successfully uninstalled protobuf-4.21.1\n",
" Attempting uninstall: qcompute\n",
" Found existing installation: qcompute 3.0.0\n",
" Uninstalling qcompute-3.0.0:\n",
" Successfully uninstalled qcompute-3.0.0\n",
"Successfully installed protobuf-3.19.6 qcompute-2.0.4 scipy-1.1.0 typing-extensions-3.10.0.0\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"--2022-11-24 13:44:45-- https://release-data.cdn.bcebos.com/PaddleQuantum/vsql.pdparams\n",
"Resolving release-data.cdn.bcebos.com (release-data.cdn.bcebos.com)... 222.35.73.1\n",
"Connecting to release-data.cdn.bcebos.com (release-data.cdn.bcebos.com)|222.35.73.1|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 857 [application/octet-stream]\n",
"Saving to: ‘vsql.pdparams.1’\n",
"\n",
"vsql.pdparams.1 100%[===================>] 857 --.-KB/s in 0s \n",
"\n",
"2022-11-24 13:44:46 (204 MB/s) - ‘vsql.pdparams.1’ saved [857/857]\n",
"\n"
]
}
],
"outputs": [],
"source": [
"# 安装量桨\n",
"%pip install paddle-quantum\n",
"%pip install --user paddle-quantum\n",
"# 下载预训练模型\n",
"!wget https://release-data.cdn.bcebos.com/PaddleQuantum/vsql.pdparams"
"!wget https://release-data.cdn.bcebos.com/PaddleQuantum/vsql.pdparams -O vsql.pdparams"
]
},
{
"cell_type": "markdown",
"id": "f8a3ebf3",
"id": "9be7e030",
"metadata": {},
"source": [
"接下来,可以加载模型并进行测试:"
"成功安装量桨之后,我们来加载 VSQL 模型和要预测的图片:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "53a5f59a",
"id": "3e883450",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/linalg/__init__.py:212: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/sparse/sputils.py:16: DeprecationWarning: `np.typeDict` is a deprecated alias for `np.sctypeDict`.\n",
" supported_dtypes = [np.typeDict[x] for x in supported_dtypes]\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/io/matlab/mio5.py:98: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from .mio5_utils import VarReader5\n"
]
}
],
"outputs": [],
"source": [
"# 导入所需要的包\n",
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"import paddle_quantum as pq\n",
"import toml\n",
"import matplotlib.pyplot as plt\n",
"from paddle_quantum.qml.vsql import VSQL\n",
"\n",
"# 设置模型参数\n",
"num_qubits = 10\n",
"num_shadow = 2\n",
"classes = [0, 1]\n",
"num_classes = len(classes)\n",
"depth = 1\n",
"\n",
"# 加载已训练的模型\n",
"model = VSQL(\n",
" num_qubits=num_qubits,\n",
" num_shadow=num_shadow,\n",
" num_classes=num_classes,\n",
" depth=depth,\n",
")\n",
"state_dict = paddle.load('./vsql.pdparams')\n",
"model.set_state_dict(state_dict)"
"from paddle_quantum.qml.vsql import inference"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8f54b4de",
"id": "1ac0fbb3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-11-24 13:45:04-- https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 222.35.73.1\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|222.35.73.1|:443... connected.\n",
"--2023-01-18 15:25:40-- https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 112.132.208.35, 116.177.239.35, 119.188.176.35, ...\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|112.132.208.35|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 290 [image/png]\n",
"Saving to: ‘data-0.png’\n",
"Saving to: ‘data_0.png’\n",
"\n",
"data-0.png 100%[===================>] 290 --.-KB/s in 0s \n",
"data_0.png 100%[===================>] 290 --.-KB/s in 0s \n",
"\n",
"2022-11-24 13:45:05 (277 MB/s) - ‘data-0.png’ saved [290/290]\n",
"2023-01-18 15:25:40 (5.66 KB/s) - ‘data_0.png’ saved [290/290]\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fe752d095d0>"
"<matplotlib.image.AxesImage at 0x7fe0b4d39fd0>"
]
},
"execution_count": 3,
......@@ -337,41 +183,113 @@
],
"source": [
"# 加载手写数字0\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347 -O data-0.png\n",
"image0 = plt.imread('data-0.png')\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347 -O data_0.png\n",
"image0 = plt.imread('data_0.png')\n",
"plt.imshow(image0)"
]
},
{
"cell_type": "markdown",
"id": "13aea688",
"metadata": {},
"source": [
"接下来,我们来配置模型参数:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "40ebcb55",
"id": "aedae237",
"metadata": {},
"outputs": [],
"source": [
"test_toml = r\"\"\"\n",
"# 模型的整体配置文件。\n",
"# 输入当前的任务,可以是 'train' 或者 'test',分别代表训练和预测。这里我们使用 test,表示我们要进行预测。\n",
"task = 'test'\n",
"# 要预测的图片的文件路径。\n",
"image_path = 'data_0.png'\n",
"# 上面的图片路径是否是文件夹。对于文件夹路径,我们会对文件夹里面的所有图片文件进行预测。这种方式可以一次测试多个图片。\n",
"is_dir = false\n",
"# 训练好的模型参数文件的文件路径。\n",
"model_path = 'vsql.pdparams'\n",
"# 量子电路所包含的量子比特的数量。\n",
"num_qubits = 10\n",
"# 影子电路所包含的量子比特的数量。\n",
"num_shadow = 2\n",
"# 电路深度。\n",
"depth = 1\n",
"# 我们要预测的类别。这里我们对 0 和 1 进行分类。\n",
"classes = [0, 1]\n",
"\"\"\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "abf2951f",
"metadata": {},
"source": [
"然后,我们使用 VSQL 模型来进行预测。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "316ff898",
"metadata": {},
"outputs": [],
"source": [
"config = toml.loads(test_toml)\n",
"task = config.pop('task')\n",
"prediction, prob = inference(**config)\n",
"prob = prob[0]\n",
"msg = '对于输入的图片,模型有'\n",
"for idx, item in enumerate(prob):\n",
" label = config['classes'][idx]\n",
" msg += f'{item:3.2%} 的信心认为它是 {label:d}'\n",
" msg += '。' if idx == len(prob) - 1 else ','\n",
"print(msg)"
]
},
{
"cell_type": "markdown",
"id": "cbd83c5b",
"metadata": {},
"source": [
"接下来,我们来测试另外一个图片:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fb4c9ced",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-11-24 13:45:06-- https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 222.35.73.1\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|222.35.73.1|:443... connected.\n",
"--2023-01-18 15:25:46-- https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 116.95.27.35, 116.177.239.35, 119.188.176.35, ...\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|116.95.27.35|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 173 [image/png]\n",
"Saving to: ‘data-1.png’\n",
"Saving to: ‘data_1.png’\n",
"\n",
"data-1.png 100%[===================>] 173 --.-KB/s in 0s \n",
"data_1.png 100%[===================>] 173 --.-KB/s in 0s \n",
"\n",
"2022-11-24 13:45:06 (165 MB/s) - ‘data-1.png’ saved [173/173]\n",
"2023-01-18 15:25:46 (82.5 MB/s) - ‘data_1.png’ saved [173/173]\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fe77301dd10>"
"<matplotlib.image.AxesImage at 0x7fe0b4f24eb0>"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
......@@ -388,48 +306,60 @@
],
"source": [
"# 加载手写数字1\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a -O data-1.png\n",
"image1 = plt.imread('data-1.png')\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a -O data_1.png\n",
"image1 = plt.imread('data_1.png')\n",
"plt.imshow(image1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c4830c1a",
"execution_count": 6,
"id": "ecb2ba3d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if data.dtype == np.object:\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/paddle/fluid/framework.py:1104: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" elif dtype == np.bool:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"对于手写数字 0,模型有 89.22% 的信心认为它是 0,有10.78%的信心认为它是 1。\n",
"对于手写数字 1,模型有 18.29% 的信心认为它是 0,有81.71%的信心认为它是 1。\n"
"对于输入的图片,模型有18.29% 的信心认为它是 0,81.71% 的信心认为它是 1。\n"
]
}
],
"source": [
"# 将图片编码为量子态\n",
"test_data = [np.array(image0).flatten(), np.array(image1).flatten()]\n",
"test_data = [np.pad(datum, pad_width=(0, 2 ** num_qubits - datum.size)) for datum in test_data]\n",
"test_data = [paddle.to_tensor(datum / np.linalg.norm(datum), dtype=pq.get_dtype()) for datum in test_data]\n",
"# 使用模型进行预测并得到对应的概率值\n",
"test_output = model(test_data)\n",
"test_prob = paddle.nn.functional.softmax(test_output)\n",
"print(f\"对于手写数字 0,模型有 {test_prob[0][0].item():3.2%} 的信心认为它是 0,有{test_prob[0][1].item():3.2%}的信心认为它是 1。\")\n",
"print(f\"对于手写数字 1,模型有 {test_prob[1][0].item():3.2%} 的信心认为它是 0,有{test_prob[1][1].item():3.2%}的信心认为它是 1。\")"
"test_toml = r\"\"\"\n",
"# 模型的整体配置文件。\n",
"# 输入当前的任务,可以是 'train' 或者 'test',分别代表训练和预测。这里我们使用 test,表示我们要进行预测。\n",
"task = 'test'\n",
"# 要预测的图片的文件路径。\n",
"image_path = 'data_1.png'\n",
"# 上面的图片路径是否是文件夹。对于文件夹路径,我们会对文件夹里面的所有图片文件进行预测。这种方式可以一次测试多个图片。\n",
"is_dir = false\n",
"# 训练好的模型参数文件的文件路径。\n",
"model_path = 'vsql.pdparams'\n",
"# 量子电路所包含的量子比特的数量。\n",
"num_qubits = 10\n",
"# 影子电路所包含的量子比特的数量。\n",
"num_shadow = 2\n",
"# 电路深度。\n",
"depth = 1\n",
"# 我们要预测的类别。这里我们对 0 和 1 进行分类。\n",
"classes = [0, 1]\n",
"\"\"\"\n",
"\n",
"config = toml.loads(test_toml)\n",
"task = config.pop('task')\n",
"# 代码还需要修改\n",
"prediction, prob = inference(**config)\n",
"if config['is_dir']:\n",
" print(f\"对输入图片的预测结果分别是 {str(prediction)[1:-1]}。\")\n",
"else:\n",
" prob = prob[0]\n",
" msg = '对于输入的图片,模型有'\n",
" for idx, item in enumerate(prob):\n",
" label = config['classes'][idx]\n",
" msg += f'{item:3.2%} 的信心认为它是 {label:d}'\n",
" msg += '。' if idx == len(prob) - 1 else ','\n",
" print(msg)"
]
},
{
......@@ -439,7 +369,11 @@
"source": [
"## 5. 注意事项\n",
"\n",
"我们提供的模型为二分类模型,仅可以用来分辨手写数字0和1。对于其它分类任务,需要重新进行训练。"
"我们提供的模型为二分类模型,仅可以用来分辨手写数字0和1。对于其它分类任务,需要重新进行训练。\n",
"\n",
"更详细的使用介绍可以参考:https://github.com/PaddlePaddle/Quantum/blob/master/applications/handwritten_digits_classification/introduction_cn.ipynb\n",
"\n",
"VSQL 模型的具体介绍可以参考:https://github.com/PaddlePaddle/Quantum/blob/master/tutorials/machine_learning/VSQL_CN.ipynb"
]
},
{
......@@ -465,7 +399,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.15 ('py37')",
"display_name": "pq-dev",
"language": "python",
"name": "python3"
},
......@@ -479,7 +413,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.15"
"version": "3.8.15"
},
"toc": {
"base_numbering": 1,
......@@ -496,7 +430,7 @@
},
"vscode": {
"interpreter": {
"hash": "49b49097121cb1ab3a8a640b71467d7eda4aacc01fc9ff84d52fcb3bd4007bf1"
"hash": "5fea01cac43c34394d065c23bb8c1e536fdb97a765a18633fd0c4eb359001810"
}
}
},
......
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "6f0a162f",
"metadata": {},
"source": [
"## 1. VSQL Introduction\n",
"\n",
"Variational Shadow Quantum Learning (VSQL) is a hybird quantum-classical framework for supervised quantum learning, which utilizes parameterized quantum circuits and classical shadows. Unlike commonly used variational quantum algorithms, the VSQL method extracts \"local\" features from the subspace instead of the whole Hilbert space."
"Variational Shadow Quantum Learning (VSQL) is a hybrid quantum-classical framework for supervised quantum learning, which utilizes parameterized quantum circuits and classical shadows. Unlike commonly used variational quantum algorithms, the VSQL method extracts \"local\" features from the subspace instead of the whole Hilbert space."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "99c07da5",
"metadata": {},
......@@ -65,12 +67,12 @@
"![2-local](https://ai-studio-static-online.cdn.bcebos.com/0c1035262cb64f61bd3cc87dbf53253aa6a7ecc170634c4db8dd71d576a9409c \"The 2-local shadow circuit design\")\n",
"<div style=\"text-align:center\">The 2-local shadow circuit design</div>\n",
"\n",
"The circuit layer in the dashed box is repeated for $D$ times to increase the expressive power of the quantum circuit. The structure of the circuit is not unique. You can try to design your own circuit."
"The circuit layer in the dashed box is repeated for $D$ times to increase the expressive power of the quantum circuit. The structure of the circuit is not unique. You can try to design your own circuit.\n"
]
},
{
"cell_type": "markdown",
"id": "bbec4432",
"id": "cf1da740",
"metadata": {},
"source": [
"## 3. Model Performance\n",
......@@ -100,125 +102,71 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "77177110",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-11-24 13:45:54-- https://release-data.cdn.bcebos.com/PaddleQuantum/vsql.pdparams\n",
"Resolving release-data.cdn.bcebos.com (release-data.cdn.bcebos.com)... 222.35.73.1\n",
"Connecting to release-data.cdn.bcebos.com (release-data.cdn.bcebos.com)|222.35.73.1|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 857 [application/octet-stream]\n",
"Saving to: ‘vsql.pdparams.2’\n",
"\n",
"vsql.pdparams.2 100%[===================>] 857 --.-KB/s in 0s \n",
"\n",
"2022-11-24 13:45:54 (817 MB/s) - ‘vsql.pdparams.2’ saved [857/857]\n",
"\n"
]
}
],
"outputs": [],
"source": [
"# Install the paddle quantum\n",
"%pip install paddle-quantum\n",
"# Download the pretrained model\n",
"!wget https://release-data.cdn.bcebos.com/PaddleQuantum/vsql.pdparams"
"%pip install --user paddle-quantum\n",
"# Download the pre-trained model\n",
"!wget https://release-data.cdn.bcebos.com/PaddleQuantum/vsql.pdparams -O vsql.pdparams"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4843c62f",
"id": "2a8ef70f",
"metadata": {},
"source": [
"Next, the model can be loaded and tested."
"After installing Paddle Quantum successfully, let's load the VSQL model and the images to be predicted."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "86d4405c",
"id": "7d88445e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/linalg/__init__.py:212: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/sparse/sputils.py:16: DeprecationWarning: `np.typeDict` is a deprecated alias for `np.sctypeDict`.\n",
" supported_dtypes = [np.typeDict[x] for x in supported_dtypes]\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/scipy/io/matlab/mio5.py:98: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from .mio5_utils import VarReader5\n"
]
}
],
"outputs": [],
"source": [
"# Import the required packages\n",
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"import paddle_quantum as pq\n",
"import toml\n",
"import matplotlib.pyplot as plt\n",
"from paddle_quantum.qml.vsql import VSQL\n",
"\n",
"# Set model parameters\n",
"num_qubits = 10\n",
"num_shadow = 2\n",
"classes = [0, 1]\n",
"num_classes = len(classes)\n",
"depth = 1\n",
"\n",
"# Load the trained model\n",
"model = VSQL(\n",
" num_qubits=num_qubits,\n",
" num_shadow=num_shadow,\n",
" num_classes=num_classes,\n",
" depth=depth,\n",
")\n",
"state_dict = paddle.load('./vsql.pdparams')\n",
"model.set_state_dict(state_dict)"
"from paddle_quantum.qml.vsql import inference"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6676b204",
"id": "d12ea921",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-11-24 13:46:01-- https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 222.35.73.1\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|222.35.73.1|:443... connected.\n",
"--2023-01-18 15:24:03-- https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 119.167.254.35, 153.35.89.225, 211.97.83.35, ...\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|119.167.254.35|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 290 [image/png]\n",
"Saving to: ‘data-0.png’\n",
"Saving to: ‘data_0.png’\n",
"\n",
"data-0.png 100%[===================>] 290 --.-KB/s in 0s \n",
"data_0.png 100%[===================>] 290 --.-KB/s in 0s \n",
"\n",
"2022-11-24 13:46:02 (138 MB/s) - ‘data-0.png’ saved [290/290]\n",
"2023-01-18 15:24:03 (138 MB/s) - ‘data_0.png’ saved [290/290]\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fa1199f8710>"
"<matplotlib.image.AxesImage at 0x7f81fc0a0fd0>"
]
},
"execution_count": 3,
......@@ -238,41 +186,125 @@
],
"source": [
"# Load handwritten digit 0\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347 -O data-0.png\n",
"image0 = plt.imread('data-0.png')\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/088dc9dbabf349c88d029dfd2e07827aa6e41ba958c5434bbd96bc167fc65347 -O data_0.png\n",
"image0 = plt.imread('data_0.png')\n",
"plt.imshow(image0)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "34dcf660",
"metadata": {},
"source": [
"Next, let's configure the model parameters."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f637d0ca",
"id": "52d92fdf",
"metadata": {},
"outputs": [],
"source": [
"test_toml = r\"\"\"\n",
"# The overall configuration file of the model.\n",
"# Enter the current task, which can be 'train' or 'test', representing training and prediction respectively. Here we use test, indicating that we want to make a prediction.\n",
"task = 'test'\n",
"# The file path of the image to be predicted.\n",
"image_path = 'data_0.png'\n",
"# Whether the image path above is a folder or not. For folder paths, we will predict all image files inside the folder. This way you can test multiple images at once.\n",
"is_dir = false\n",
"# The file path of the trained model parameter file.\n",
"model_path = 'vsql.pdparams'\n",
"# The number of qubits that the quantum circuit contains.\n",
"num_qubits = 10\n",
"# The number of qubits that the shadow circuit contains.\n",
"num_shadow = 2\n",
"# Circuit depth.\n",
"depth = 1\n",
"# The class to be predicted by the model. Here, 0 and 1 are classified.\n",
"classes = [0, 1]\n",
"\"\"\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "35514ce7",
"metadata": {},
"source": [
"Then, we use the VSQL model to make predictions."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ca32fb07",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For the input image, the model has 89.22% confidence that it is 0, and 10.78% confidence that it is 1.\n"
]
}
],
"source": [
"config = toml.loads(test_toml)\n",
"task = config.pop('task')\n",
"prediction, prob = inference(**config)\n",
"prob = prob[0]\n",
"msg = 'For the input image, the model has'\n",
"for idx, item in enumerate(prob):\n",
" if idx == len(prob) - 1:\n",
" msg += 'and'\n",
" label = config['classes'][idx]\n",
" msg += f' {item:3.2%} confidence that it is {label:d}'\n",
" msg += '.' if idx == len(prob) - 1 else ', '\n",
"print(msg)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "390302db",
"metadata": {},
"source": [
"Next, let's test another image."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4b0b6abe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-11-24 13:46:03-- https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 222.35.73.1\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|222.35.73.1|:443... connected.\n",
"--2023-01-18 15:24:12-- https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a\n",
"Resolving ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)... 119.176.25.35, 153.35.89.225, 211.97.83.35, ...\n",
"Connecting to ai-studio-static-online.cdn.bcebos.com (ai-studio-static-online.cdn.bcebos.com)|119.176.25.35|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 173 [image/png]\n",
"Saving to: ‘data-1.png’\n",
"Saving to: ‘data_1.png’\n",
"\n",
"data-1.png 100%[===================>] 173 --.-KB/s in 0s \n",
"data_1.png 100%[===================>] 173 --.-KB/s in 0s \n",
"\n",
"2022-11-24 13:46:03 (165 MB/s) - ‘data-1.png’ saved [173/173]\n",
"2023-01-18 15:24:12 (3.38 KB/s) - ‘data_1.png’ saved [173/173]\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fa1008969d0>"
"<matplotlib.image.AxesImage at 0x7f81dd91eb50>"
]
},
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
......@@ -289,66 +321,73 @@
],
"source": [
"# Load handwritten digit 1\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a -O data-1.png\n",
"image1 = plt.imread('data-1.png')\n",
"!wget https://ai-studio-static-online.cdn.bcebos.com/c755f723af3d4a1c8f113f8ac3bd365406decd1be70944b7b7b9d41413e8bc7a -O data_1.png\n",
"image1 = plt.imread('data_1.png')\n",
"plt.imshow(image1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e40d847a",
"execution_count": 7,
"id": "a5bcba99",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if data.dtype == np.object:\n",
"/Users/wangzihe/opt/anaconda3/envs/py37/lib/python3.7/site-packages/paddle/fluid/framework.py:1104: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" elif dtype == np.bool:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"For handwritten digits 0, the model has 89.22% confidence that it is 0 and 10.78% confidence that it is 1.\n",
"For handwritten digits 1, the model has 18.29% confidence that it is 0 and 81.71% confidence that it is 1.\n"
"For the input image, the model has 18.29% confidence that it is 0, and 81.71% confidence that it is 1.\n"
]
}
],
"source": [
"# Encoding images into quantum states\n",
"test_data = [np.array(image0).flatten(), np.array(image1).flatten()]\n",
"test_data = [np.pad(datum, pad_width=(0, 2 ** num_qubits - datum.size)) for datum in test_data]\n",
"test_data = [paddle.to_tensor(datum / np.linalg.norm(datum), dtype=pq.get_dtype()) for datum in test_data]\n",
"# Use the model to make predictions and get the corresponding probability\n",
"test_output = model(test_data)\n",
"test_prob = paddle.nn.functional.softmax(test_output)\n",
"print(\n",
" f\"For handwritten digits 0, \"\n",
" f\"the model has {test_prob[0][0].item():3.2%} confidence that it is 0 \"\n",
" f\"and {test_prob[0][1].item():3.2%} confidence that it is 1.\"\n",
")\n",
"print(\n",
" f\"For handwritten digits 1, \"\n",
" f\"the model has {test_prob[1][0].item():3.2%} confidence that it is 0 \"\n",
" f\"and {test_prob[1][1].item():3.2%} confidence that it is 1.\"\n",
")"
"test_toml = r\"\"\"\n",
"# The overall configuration file of the model.\n",
"# Enter the current task, which can be 'train' or 'test', representing training and prediction respectively. Here we use test, indicating that we want to make a prediction.\n",
"task = 'test'\n",
"# The file path of the image to be predicted.\n",
"image_path = 'data_1.png'\n",
"# Whether the image path above is a folder or not. For folder paths, we will predict all image files inside the folder. This way you can test multiple images at once.\n",
"is_dir = false\n",
"# The file path of the trained model parameter file.\n",
"model_path = 'vsql.pdparams'\n",
"# The number of qubits that the quantum circuit contains.\n",
"num_qubits = 10\n",
"# The number of qubits that the shadow circuit contains.\n",
"num_shadow = 2\n",
"# Circuit depth.\n",
"depth = 1\n",
"# The class to be predicted by the model. Here, 0 and 1 are classified.\n",
"classes = [0, 1]\n",
"\"\"\"\n",
"\n",
"config = toml.loads(test_toml)\n",
"task = config.pop('task')\n",
"prediction, prob = inference(**config)\n",
"prob = prob[0]\n",
"msg = 'For the input image, the model has'\n",
"for idx, item in enumerate(prob):\n",
" if idx == len(prob) - 1:\n",
" msg += 'and'\n",
" label = config['classes'][idx]\n",
" msg += f' {item:3.2%} confidence that it is {label:d}'\n",
" msg += '.' if idx == len(prob) - 1 else ', '\n",
"print(msg)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3990efae",
"metadata": {},
"source": [
"## 5. Note\n",
"\n",
"The model we provide is a binary classification model that can only be used to distinguish handwritten digits 0 and 1. For other classification tasks, it needs to be retrained."
"The model we provide is a binary classification model that can only be used to distinguish handwritten digits 0 and 1. For other classification tasks, it needs to be retrained.\n",
"\n",
"A more detailed description of the use can be found at https://github.com/PaddlePaddle/Quantum/blob/master/applications/handwritten_digits_classification/introduction_en.ipynb .\n",
"\n",
"A detailed description of the VSQL model can be found at https://github.com/PaddlePaddle/Quantum/blob/master/tutorials/machine_learning/VSQL_EN.ipynb ."
]
},
{
......@@ -374,7 +413,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.15 ('py37')",
"display_name": "pq-dev",
"language": "python",
"name": "python3"
},
......@@ -388,7 +427,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.15"
"version": "3.8.15 (default, Nov 10 2022, 13:17:42) \n[Clang 14.0.6 ]"
},
"toc": {
"base_numbering": 1,
......@@ -405,7 +444,7 @@
},
"vscode": {
"interpreter": {
"hash": "49b49097121cb1ab3a8a640b71467d7eda4aacc01fc9ff84d52fcb3bd4007bf1"
"hash": "5fea01cac43c34394d065c23bb8c1e536fdb97a765a18633fd0c4eb359001810"
}
}
},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册