diff --git a/demos/audio_searching/README.md b/demos/audio_searching/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0ee781ad26a140b63421fd51dc37500ead44ec84 --- /dev/null +++ b/demos/audio_searching/README.md @@ -0,0 +1,95 @@ +([简体中文](./README_cn.md)|English) + +# Audio Searching + +## Introduction +This demo uses ECAPA-TDNN (or other models) for speaker recognition, based on MySQL to store user info/IDs and Milvus to search vectors. + +## Usage +### 1. Prepare MySQL and Milvus services by docker-compose +The audio similarity search system requires the Milvus and MySQL services. We can start these containers with one click through [docker-compose.yaml](./docker-compose.yaml), so please make sure you have [installed Docker Engine](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) before running. Then run: + +```bash +docker-compose -f docker-compose.yaml up -d +``` + +You will then see that all containers are created: + +```bash +Creating network "quick_deploy_app_net" with driver "bridge" +Creating milvus-minio ... done +Creating milvus-etcd ... done +Creating audio-mysql ... done +Creating milvus-standalone ... done +``` + +List all containers with `docker ps`; you can use `docker logs audio-mysql` to view the logs of the MySQL container: + +```bash +CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES +b2bcf279e599 milvusdb/milvus:v2.0.1 "/tini -- milvus run…" 22 hours ago Up 22 hours 0.0.0.0:19530->19530/tcp milvus-standalone +d8ef4c84e25c mysql:5.7 "docker-entrypoint.s…" 22 hours ago Up 22 hours 0.0.0.0:3306->3306/tcp, 33060/tcp audio-mysql +8fb501edb4f3 quay.io/coreos/etcd:v3.5.0 "etcd -advertise-cli…" 22 hours ago Up 22 hours 2379-2380/tcp milvus-etcd +ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" 22 hours ago Up 22 hours (healthy) 9000/tcp milvus-minio + +``` + +### 2. 
Start API Server +Next, start the system server, which provides the HTTP backend services. + +- Install the Python packages + +```bash +pip install -r requirements.txt +``` +- Set configuration + +```bash +vim src/config.py +``` + +Modify the parameters according to your own environment. Some of the parameters that need to be set are listed here; for more information, please refer to [config.py](./src/config.py). + +| **Parameter** | **Description** | **Default setting** | +| ---------------- | ----------------------------------------------------- | ------------------- | +| MILVUS_HOST | The IP address of Milvus; you can get it with ifconfig. If everything runs on one machine, it is most likely 127.0.0.1. | 127.0.0.1 | +| MILVUS_PORT | Port of Milvus. | 19530 | +| VECTOR_DIMENSION | Dimension of the vectors. | 2048 | +| MYSQL_HOST | The IP address of MySQL. | 127.0.0.1 | +| MYSQL_PORT | Port of MySQL. | 3306 | +| DEFAULT_TABLE | The default collection name in Milvus and MySQL. | audio_table | + +- Run the code + +Start the server with FastAPI. + +```bash +python src/main.py +``` + +You will then see that the application has started: + +```bash +INFO: Started server process [3949] +2022-03-07 17:39:14,864 | INFO | server.py | serve | 75 | Started server process [3949] +INFO: Waiting for application startup. +2022-03-07 17:39:14,865 | INFO | on.py | startup | 45 | Waiting for application startup. +INFO: Application startup complete. +2022-03-07 17:39:14,866 | INFO | on.py | startup | 59 | Application startup complete. +INFO: Uvicorn running on http://127.0.0.1:8002 (Press CTRL+C to quit) +2022-03-07 17:39:14,867 | INFO | server.py | _log_started_message | 206 | Uvicorn running on http://127.0.0.1:8002 (Press CTRL+C to quit) +``` + +### 3. 
Usage + + ```bash + python ./src/test_main.py + ``` + +### 4. Pretrained Models + +Here is a list of pretrained models released by PaddleSpeech that can be used via the command line and the Python API: + +| Model | Sample Rate +| :--- | :---: +| ecapa_tdnn | 16000 diff --git a/demos/audio_searching/docker-compose.yaml b/demos/audio_searching/docker-compose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e16ee8a0ab5610b517a08359f2b8345fa270b1f --- /dev/null +++ b/demos/audio_searching/docker-compose.yaml @@ -0,0 +1,73 @@ +version: '3.5' + +services: + etcd: + container_name: milvus-etcd + image: quay.io/coreos/etcd:v3.5.0 + networks: + app_net: + environment: + - ETCD_AUTO_COMPACTION_MODE=revision + - ETCD_AUTO_COMPACTION_RETENTION=1000 + - ETCD_QUOTA_BACKEND_BYTES=4294967296 + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd + + minio: + container_name: milvus-minio + image: minio/minio:RELEASE.2020-12-03T00-03-10Z + networks: + app_net: + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data + command: minio server /minio_data + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 20s + retries: 3 + + standalone: + container_name: milvus-standalone + image: milvusdb/milvus:v2.0.1 + networks: + app_net: + ipv4_address: 172.16.23.10 + command: ["milvus", "run", "standalone"] + environment: + ETCD_ENDPOINTS: etcd:2379 + MINIO_ADDRESS: minio:9000 + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus + ports: + - "19530:19530" + depends_on: + - "etcd" + - "minio" + + + mysql: + container_name: audio-mysql + image: mysql:5.7 + networks: + app_net: + ipv4_address: 172.16.23.11 + environment: + - MYSQL_ROOT_PASSWORD=123456 + volumes: + - 
${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql + ports: + - "3306:3306" + +networks: + app_net: + driver: bridge + ipam: + driver: default + config: + - subnet: 172.16.23.0/24 + gateway: 172.16.23.1 diff --git a/demos/audio_searching/requirements.txt b/demos/audio_searching/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9e73361b47327783d58741646d5051fabaebf226 --- /dev/null +++ b/demos/audio_searching/requirements.txt @@ -0,0 +1,12 @@ +soundfile==0.10.3.post1 +librosa==0.8.0 +numpy +pymysql +fastapi +uvicorn +diskcache==5.2.1 +pymilvus==2.0.1 +python-multipart +typing +starlette +pydantic \ No newline at end of file diff --git a/demos/audio_searching/src/config.py b/demos/audio_searching/src/config.py new file mode 100644 index 0000000000000000000000000000000000000000..72a8fb4beadb6fdd3df7801d13d83a42678219d0 --- /dev/null +++ b/demos/audio_searching/src/config.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +############### Milvus Configuration ############### +MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1") +MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530")) +VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048")) +INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024")) +METRIC_TYPE = os.getenv("METRIC_TYPE", "L2") +DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table") +TOP_K = int(os.getenv("TOP_K", "10")) + +############### MySQL Configuration ############### +MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1") +MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306")) +MYSQL_USER = os.getenv("MYSQL_USER", "root") +MYSQL_PWD = os.getenv("MYSQL_PWD", "123456") +MYSQL_DB = os.getenv("MYSQL_DB", "mysql") + +############### Data Path ############### +UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data") + +############### Number of Log Files ############### +LOGS_NUM = int(os.getenv("logs_num", "0")) diff --git a/demos/audio_searching/src/encode.py b/demos/audio_searching/src/encode.py new file mode 100644 index 0000000000000000000000000000000000000000..391822c76da44012c8199d1fbc92acab4b85b100 --- /dev/null +++ b/demos/audio_searching/src/encode.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import librosa +import numpy as np +from logs import LOGGER + + +def get_audio_embedding(path): + """ + Use vpr_inference to generate embedding of audio + """ + try: + RESAMPLE_RATE = 16000 + audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True) + + # TODO add infer/python interface to get embedding, now fake it by rand + # vpr = ECAPATDNN(checkpoint_path=None, device='cuda') + # embedding = vpr.inference(audio) + + embedding = np.random.rand(1, 2048) + embedding = embedding / np.linalg.norm(embedding) + embedding = embedding.tolist()[0] + return embedding + except Exception as e: + LOGGER.error(f"Error with embedding:{e}") + return None diff --git a/demos/audio_searching/src/logs.py b/demos/audio_searching/src/logs.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3ed069c6428797353b1adcdfb0f5b18b02a8ad --- /dev/null +++ b/demos/audio_searching/src/logs.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import codecs +import datetime +import logging +import os +import re +import sys + +from config import LOGS_NUM + + +class MultiprocessHandler(logging.FileHandler): + """ + A handler class which writes formatted logging records to disk files + """ + + def __init__(self, + filename, + when='D', + backupCount=0, + encoding=None, + delay=False): + """ + Open the specified file and use it as the stream for logging + """ + self.prefix = filename + self.backupCount = backupCount + self.when = when.upper() + self.extMath = r"^\d{4}-\d{2}-\d{2}" + + self.when_dict = { + 'S': "%Y-%m-%d-%H-%M-%S", + 'M': "%Y-%m-%d-%H-%M", + 'H': "%Y-%m-%d-%H", + 'D': "%Y-%m-%d" + } + + self.suffix = self.when_dict.get(self.when) + if not self.suffix: + print('The specified date interval unit is invalid: ', self.when) + sys.exit(1) + + self.filefmt = os.path.join('.', "logs", + f"{self.prefix}-{self.suffix}.log") + + self.filePath = datetime.datetime.now().strftime(self.filefmt) + + _dir = os.path.dirname(self.filefmt) + try: + if not os.path.exists(_dir): + os.makedirs(_dir) + except Exception as e: + print('Failed to create log directory: ', e) + print("log_path:" + self.filePath) + sys.exit(1) + + logging.FileHandler.__init__(self, self.filePath, 'a+', encoding, delay) + + def should_change_file_to_write(self): + """ + Check whether the log file path for the current time differs from the one in use + """ + _filePath = datetime.datetime.now().strftime(self.filefmt) + if _filePath != self.filePath: + self.filePath = _filePath + return True + return False + + def do_change_file(self): + """ + Switch the stream to the current log file and remove expired backups + """ + self.baseFilename = os.path.abspath(self.filePath) + if self.stream: + self.stream.close() + self.stream = None + + if not self.delay: + self.stream = self._open() + if self.backupCount > 0: + for s in self.get_files_to_delete(): + os.remove(s) + + def get_files_to_delete(self): + """ + Collect the old log files that exceed backupCount + """ + dir_name, _ = os.path.split(self.baseFilename) + file_names = os.listdir(dir_name) + result = [] + prefix = self.prefix + '-' + 
for file_name in file_names: + if file_name[:len(prefix)] == prefix: + suffix = file_name[len(prefix):-4] + if re.compile(self.extMath).match(suffix): + result.append(os.path.join(dir_name, file_name)) + result.sort() + + if len(result) < self.backupCount: + result = [] + else: + result = result[:len(result) - self.backupCount] + return result + + def emit(self, record): + """ + Emit a record + """ + try: + if self.should_change_file_to_write(): + self.do_change_file() + logging.FileHandler.emit(self, record) + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.handleError(record) + + +def write_log(): + """ + Init a logger + """ + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + # formatter = '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(module)s | %(lineno)s | %(message)s' + fmt = logging.Formatter( + '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(lineno)s | %(message)s' + ) + + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(fmt) + + log_name = "audio-searching" + file_handler = MultiprocessHandler(log_name, when='D', backupCount=LOGS_NUM) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(fmt) + file_handler.do_change_file() + + logger.addHandler(stream_handler) + logger.addHandler(file_handler) + + return logger + + +LOGGER = write_log() + +if __name__ == "__main__": + message = 'test writing logs' + LOGGER.info(message) + LOGGER.debug(message) + LOGGER.error(message) diff --git a/demos/audio_searching/src/main.py b/demos/audio_searching/src/main.py new file mode 100644 index 0000000000000000000000000000000000000000..89c037a0e6adc056c44d3127b9dee5f36b8dc368 --- /dev/null +++ b/demos/audio_searching/src/main.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional + +import uvicorn +from config import UPLOAD_PATH +from diskcache import Cache +from fastapi import FastAPI +from fastapi import File +from fastapi import UploadFile +from logs import LOGGER +from milvus_helpers import MilvusHelper +from mysql_helpers import MySQLHelper +from operations.count import do_count +from operations.drop import do_drop +from operations.load import do_load +from operations.search import do_search +from pydantic import BaseModel +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import FileResponse + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"]) + +MODEL = None +MILVUS_CLI = MilvusHelper() +MYSQL_CLI = MySQLHelper() + +# Mkdir 'tmp/audio-data' +if not os.path.exists(UPLOAD_PATH): + os.makedirs(UPLOAD_PATH) + LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}") + + +@app.get('/data') +def audio_path(audio_path): + # Get the audio file + try: + LOGGER.info(f"Successfully load audio: {audio_path}") + return FileResponse(audio_path) + except Exception as e: + LOGGER.error(f"upload audio error: {e}") + return {'status': False, 'msg': e}, 400 + + +@app.get('/progress') +def get_progress(): + # Get the progress of dealing with data + try: + cache = Cache('./tmp') + return f"current: {cache['current']}, 
total: {cache['total']}" + except Exception as e: + LOGGER.error(f"Upload data error: {e}") + return {'status': False, 'msg': e}, 400 + + +class Item(BaseModel): + Table: Optional[str] = None + File: str + + +@app.post('/audio/load') +async def load_audios(item: Item): + # Insert all the audio files under the file path to Milvus/MySQL + try: + total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI) + LOGGER.info(f"Successfully loaded data, total count: {total_num}") + return {'status': True, 'msg': "Successfully loaded data!"} + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/audio/search') +async def search_audio(request: Request, + table_name: str=None, + audio: UploadFile=File(...)): + # Search the uploaded audio in Milvus/MySQL + try: + # Save the upload data to server. + content = await audio.read() + query_audio_path = os.path.join(UPLOAD_PATH, audio.filename) + with open(query_audio_path, "wb+") as f: + f.write(content) + host = request.headers['host'] + _, paths, distances = do_search(host, table_name, query_audio_path, + MILVUS_CLI, MYSQL_CLI) + names = [] + for i in paths: + names.append(os.path.basename(i)) + res = dict(zip(paths, zip(names, distances))) + # Sort results by distance metric, closest distances first + res = sorted(res.items(), key=lambda item: item[1][1]) + LOGGER.info("Successfully searched similar audio!") + return res + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/audio/search/local') +async def search_local_audio(request: Request, + query_audio_path: str, + table_name: str=None): + # Search the uploaded audio in Milvus/MySQL + try: + host = request.headers['host'] + _, paths, distances = do_search(host, table_name, query_audio_path, + MILVUS_CLI, MYSQL_CLI) + names = [] + for i in paths: + names.append(os.path.basename(i)) + res = dict(zip(paths, zip(names, distances))) + # Sort results by distance metric, closest 
distances first + res = sorted(res.items(), key=lambda item: item[1][1]) + LOGGER.info("Successfully searched similar audio!") + return res + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.get('/audio/count') +async def count_audio(table_name: str=None): + # Returns the total number of vectors in the system + try: + num = do_count(table_name, MILVUS_CLI) + LOGGER.info("Successfully count the number of data!") + return num + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/audio/drop') +async def drop_tables(table_name: str=None): + # Delete the collection of Milvus and MySQL + try: + status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI) + LOGGER.info("Successfully drop tables in Milvus and MySQL!") + return status + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +if __name__ == '__main__': + uvicorn.run(app=app, host='127.0.0.1', port=8002) diff --git a/demos/audio_searching/src/milvus_helpers.py b/demos/audio_searching/src/milvus_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba3776be8871af60bd7f91225b218ff30713c4b --- /dev/null +++ b/demos/audio_searching/src/milvus_helpers.py @@ -0,0 +1,186 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys + +from config import METRIC_TYPE +from config import MILVUS_HOST +from config import MILVUS_PORT +from config import VECTOR_DIMENSION +from logs import LOGGER +from pymilvus import Collection +from pymilvus import CollectionSchema +from pymilvus import connections +from pymilvus import DataType +from pymilvus import FieldSchema +from pymilvus import utility + + +class MilvusHelper: + """ + the basic operations of PyMilvus + + # This example shows how to: + # 1. connect to Milvus server + # 2. create a collection + # 3. insert entities + # 4. create index + # 5. search + # 6. delete a collection + + """ + + def __init__(self): + try: + self.collection = None + connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) + LOGGER.debug( + f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}" + ) + except Exception as e: + LOGGER.error(f"Failed to connect Milvus: {e}") + sys.exit(1) + + def set_collection(self, collection_name): + try: + if self.has_collection(collection_name): + self.collection = Collection(name=collection_name) + else: + raise Exception( + f"There is no collection named:{collection_name}") + except Exception as e: + LOGGER.error(f"Failed to set collection in Milvus: {e}") + sys.exit(1) + + def has_collection(self, collection_name): + # Return if Milvus has the collection + try: + return utility.has_collection(collection_name) + except Exception as e: + LOGGER.error(f"Failed to check collection in Milvus: {e}") + sys.exit(1) + + def create_collection(self, collection_name): + # Create milvus collection if not exists + try: + if not self.has_collection(collection_name): + field1 = FieldSchema( + name="id", + dtype=DataType.INT64, + description="int64", + is_primary=True, + auto_id=True) + field2 = FieldSchema( + name="embedding", + dtype=DataType.FLOAT_VECTOR, + description="speaker embeddings", + dim=VECTOR_DIMENSION, + is_primary=False) + schema = CollectionSchema( + fields=[field1, field2], description="embeddings info") + 
self.collection = Collection( + name=collection_name, schema=schema) + LOGGER.debug(f"Create Milvus collection: {collection_name}") + else: + self.set_collection(collection_name) + return "OK" + except Exception as e: + LOGGER.error(f"Failed to create collection in Milvus: {e}") + sys.exit(1) + + def insert(self, collection_name, vectors): + # Batch insert vectors to milvus collection + try: + self.create_collection(collection_name) + data = [vectors] + self.set_collection(collection_name) + mr = self.collection.insert(data) + ids = mr.primary_keys + self.collection.load() + LOGGER.debug( + f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows" + ) + return ids + except Exception as e: + LOGGER.error(f"Failed to load data to Milvus: {e}") + sys.exit(1) + + def create_index(self, collection_name): + # Create IVF_SQ8 index on milvus collection + try: + self.set_collection(collection_name) + default_index = { + "index_type": "IVF_SQ8", + "metric_type": METRIC_TYPE, + "params": { + "nlist": 16384 + } + } + status = self.collection.create_index( + field_name="embedding", index_params=default_index) + if not status.code: + LOGGER.debug( + f"Successfully create index in collection:{collection_name} with param:{default_index}" + ) + return status + else: + raise Exception(status.message) + except Exception as e: + LOGGER.error(f"Failed to create index: {e}") + sys.exit(1) + + def delete_collection(self, collection_name): + # Delete Milvus collection + try: + self.set_collection(collection_name) + self.collection.drop() + LOGGER.debug("Successfully drop collection!") + return "ok" + except Exception as e: + LOGGER.error(f"Failed to drop collection: {e}") + sys.exit(1) + + def search_vectors(self, collection_name, vectors, top_k): + # Search vector in milvus collection + try: + self.set_collection(collection_name) + search_params = { + "metric_type": METRIC_TYPE, + "params": { + "nprobe": 16 + } + } + # data = [vectors] + res = self.collection.search( 
+ vectors, + anns_field="embedding", + param=search_params, + limit=top_k) + LOGGER.debug(f"Successfully search in collection: {res}") + return res + except Exception as e: + LOGGER.error(f"Failed to search vectors in Milvus: {e}") + sys.exit(1) + + def count(self, collection_name): + # Get the number of milvus collection + try: + self.set_collection(collection_name) + num = self.collection.num_entities + LOGGER.debug( + f"Successfully get the num:{num} of the collection:{collection_name}" + ) + return num + except Exception as e: + LOGGER.error(f"Failed to count vectors in Milvus: {e}") + sys.exit(1) diff --git a/demos/audio_searching/src/mysql_helpers.py b/demos/audio_searching/src/mysql_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..303838399be9c99c8c2d01ceff211d1fdd20b02f --- /dev/null +++ b/demos/audio_searching/src/mysql_helpers.py @@ -0,0 +1,133 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +import pymysql +from config import MYSQL_DB +from config import MYSQL_HOST +from config import MYSQL_PORT +from config import MYSQL_PWD +from config import MYSQL_USER +from logs import LOGGER + + +class MySQLHelper(): + """ + the basic operations of PyMySQL + + # This example shows how to: + # 1. connect to MySQL server + # 2. create a table + # 3. insert data to table + # 4. search by milvus ids + # 5. 
delete table + """ + + def __init__(self): + self.conn = pymysql.connect( + host=MYSQL_HOST, + user=MYSQL_USER, + port=MYSQL_PORT, + password=MYSQL_PWD, + database=MYSQL_DB, + local_infile=True) + self.cursor = self.conn.cursor() + + def test_connection(self): + try: + self.conn.ping() + except Exception: + self.conn = pymysql.connect( + host=MYSQL_HOST, + user=MYSQL_USER, + port=MYSQL_PORT, + password=MYSQL_PWD, + database=MYSQL_DB, + local_infile=True) + self.cursor = self.conn.cursor() + + def create_mysql_table(self, table_name): + # Create mysql table if not exists + self.test_connection() + sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);" + try: + self.cursor.execute(sql) + LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}") + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def load_data_to_mysql(self, table_name, data): + # Batch insert (milvus_id, audio_path) to mysql + self.test_connection() + sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);" + try: + self.cursor.executemany(sql, data) + self.conn.commit() + LOGGER.debug( + f"MYSQL loads data to table: {table_name} successfully") + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def search_by_milvus_ids(self, ids, table_name): + # Get the audio_path according to the milvus ids + self.test_connection() + str_ids = str(ids).replace('[', '').replace(']', '') + sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");" + try: + self.cursor.execute(sql) + results = self.cursor.fetchall() + results = [res[0] for res in results] + LOGGER.debug("MYSQL search by milvus id.") + return results + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def delete_table(self, table_name): + # Delete mysql table if exists + self.test_connection() + 
sql = "drop table if exists " + table_name + ";" + try: + self.cursor.execute(sql) + LOGGER.debug(f"MYSQL delete table:{table_name}") + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def delete_all_data(self, table_name): + # Delete all the data in mysql table + self.test_connection() + sql = 'delete from ' + table_name + ';' + try: + self.cursor.execute(sql) + self.conn.commit() + LOGGER.debug(f"MYSQL delete all data in table:{table_name}") + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def count_table(self, table_name): + # Get the number of mysql table + self.test_connection() + sql = "select count(milvus_id) from " + table_name + ";" + try: + self.cursor.execute(sql) + results = self.cursor.fetchall() + LOGGER.debug(f"MYSQL count table:{table_name}") + return results[0][0] + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) diff --git a/demos/audio_searching/src/operations/__init__.py b/demos/audio_searching/src/operations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/demos/audio_searching/src/operations/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/demos/audio_searching/src/operations/count.py b/demos/audio_searching/src/operations/count.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1f4208213acfb7bea340940edd7a28ecb88d81 --- /dev/null +++ b/demos/audio_searching/src/operations/count.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from config import DEFAULT_TABLE +from logs import LOGGER + + +def do_count(table_name, milvus_cli): + """ + Returns the total number of vectors in the system + """ + if not table_name: + table_name = DEFAULT_TABLE + try: + if not milvus_cli.has_collection(table_name): + return None + num = milvus_cli.count(table_name) + return num + except Exception as e: + LOGGER.error(f"Error attempting to count table {e}") + sys.exit(1) diff --git a/demos/audio_searching/src/operations/drop.py b/demos/audio_searching/src/operations/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..f8278ddd04a56238c17bd7d8683415bb45aab785 --- /dev/null +++ b/demos/audio_searching/src/operations/drop.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from config import DEFAULT_TABLE +from logs import LOGGER + + +def do_drop(table_name, milvus_cli, mysql_cli): + """ + Delete the collection from Milvus and the table from MySQL + """ + if not table_name: + table_name = DEFAULT_TABLE + try: + if not milvus_cli.has_collection(table_name): + return "Collection does not exist" + status = milvus_cli.delete_collection(table_name) + mysql_cli.delete_table(table_name) + return status + except Exception as e: + LOGGER.error(f"Error attempting to drop table: {e}") + sys.exit(1) diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py new file mode 100644 index 0000000000000000000000000000000000000000..792434fbe3a6b2ba44f785c5ce54648539f8084a --- /dev/null +++ b/demos/audio_searching/src/operations/load.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import os +import sys + +from diskcache import Cache +from encode import get_audio_embedding + +from config import DEFAULT_TABLE +from logs import LOGGER + + +def get_audios(path): + """ + List all supported audio files (wav, mp3, ogg, flac, m4a) recursively under the path folder. + """ + supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"] + return [ + item + for sublist in [[os.path.join(dir, file) for file in files] + for dir, _, files in list(os.walk(path))] + for item in sublist if os.path.splitext(item)[1] in supported_formats + ] + + +def extract_features(audio_dir): + """ + Get the vector of audio + """ + try: + cache = Cache('./tmp') + feats = [] + names = [] + audio_list = get_audios(audio_dir) + total = len(audio_list) + cache['total'] = total + for i, audio_path in enumerate(audio_list): + norm_feat = get_audio_embedding(audio_path) + if norm_feat is None: + continue + feats.append(norm_feat) + names.append(audio_path.encode()) + cache['current'] = i + 1 + print( + f"Extracting feature from audio No.
{i + 1} , {total} audios in total" + ) + return feats, names + except Exception as e: + LOGGER.error(f"Error with extracting feature from audio {e}") + sys.exit(1) + + +def format_data(ids, names): + """ + Combine the id of the vector and the name of the audio into a list + """ + data = [] + for i in range(len(ids)): + value = (str(ids[i]), names[i]) + data.append(value) + return data + + +def do_load(table_name, audio_dir, milvus_cli, mysql_cli): + """ + Import vectors to Milvus and data to Mysql respectively + """ + if not table_name: + table_name = DEFAULT_TABLE + vectors, names = extract_features(audio_dir) + ids = milvus_cli.insert(table_name, vectors) + milvus_cli.create_index(table_name) + mysql_cli.create_mysql_table(table_name) + mysql_cli.load_data_to_mysql(table_name, format_data(ids, names)) + return len(ids) diff --git a/demos/audio_searching/src/operations/search.py b/demos/audio_searching/src/operations/search.py new file mode 100644 index 0000000000000000000000000000000000000000..861fee01a448c10c567a951f5814ac34f0977a39 --- /dev/null +++ b/demos/audio_searching/src/operations/search.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys + +from config import DEFAULT_TABLE +from config import TOP_K +from encode import get_audio_embedding +from logs import LOGGER + + +def do_search(host, table_name, audio_path, milvus_cli, mysql_cli): + """ + Search the uploaded audio in Milvus/MySQL + """ + try: + if not table_name: + table_name = DEFAULT_TABLE + feat = get_audio_embedding(audio_path) + vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K) + vids = [str(x.id) for x in vectors[0]] + paths = mysql_cli.search_by_milvus_ids(vids, table_name) + distances = [x.distance for x in vectors[0]] + for i in range(len(paths)): + tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i]) + paths[i] = tmp + return vids, paths, distances + except Exception as e: + LOGGER.error(f"Error with search: {e}") + sys.exit(1) diff --git a/demos/audio_searching/src/test_main.py b/demos/audio_searching/src/test_main.py new file mode 100644 index 0000000000000000000000000000000000000000..24405f38826ee6a471c8b38e8d7911578f7dbaf2 --- /dev/null +++ b/demos/audio_searching/src/test_main.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import zipfile + +import gdown +from fastapi.testclient import TestClient +from main import app + +client = TestClient(app) + + +def download_audio_data(): + """ + Download the example audio data and extract it + """ + url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp' + gdown.download(url) + + with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref: + zip_ref.extractall('./example_audio') + + +def test_drop(): + """ + Delete the collection of Milvus and MySQL + """ + response = client.post("/audio/drop") + assert response.status_code == 200 + + +def test_load(): + """ + Insert all the audio files under the file path to Milvus/MySQL + """ + response = client.post("/audio/load", json={"File": "./example_audio"}) + assert response.status_code == 200 + assert response.json() == { + 'status': True, + 'msg': "Successfully loaded data!" + } + + +def test_progress(): + """ + Get the progress of dealing with data + """ + response = client.get("/progress") + assert response.status_code == 200 + assert response.json() == "current: 20, total: 20" + + +def test_count(): + """ + Returns the total number of vectors in the system + """ + response = client.get("/audio/count") + assert response.status_code == 200 + assert response.json() == 20 + + +def test_search(): + """ + Search the uploaded audio in Milvus/MySQL + """ + response = client.post( + "/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav") + assert response.status_code == 200 + assert len(response.json()) == 10 + + +def test_data(): + """ + Get the audio file + """ + response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav") + assert response.status_code == 200 + + +if __name__ == "__main__": + download_audio_data() + test_drop() + test_load() + test_count() + test_search() + test_drop()
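
The recursive file discovery that `get_audios` in `load.py` performs can be sketched as a standalone function. This is only an illustration, not the code in this change: the names `find_audio_files` and `SUPPORTED_FORMATS` are hypothetical, and the sketch assumes that case-insensitive extension matching and sorted output are wanted (the original compares extensions case-sensitively and keeps `os.walk` order).

```python
import os

# Hypothetical constant mirroring the extension list used in load.py.
SUPPORTED_FORMATS = {".wav", ".mp3", ".ogg", ".flac", ".m4a"}


def find_audio_files(root):
    """Return the sorted paths of all supported audio files under root."""
    matches = []
    for dirpath, _, filenames in os.walk(root):
        for filename in filenames:
            # Assumption: compare extensions case-insensitively.
            extension = os.path.splitext(filename)[1].lower()
            if extension in SUPPORTED_FORMATS:
                matches.append(os.path.join(dirpath, filename))
    return sorted(matches)
```

Sorting makes the order of extracted features reproducible across runs, which may help when comparing loaded IDs against file names.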