audio_handler.py 20.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
lym0302 已提交
14
import base64
15 16
import json
import logging
L
lym0302 已提交
17 18
import threading
import time
19 20

import numpy as np
L
lym0302 已提交
21
import requests
22 23
import soundfile
import websockets
24

25
from paddlespeech.cli.log import logger
L
lym0302 已提交
26
from paddlespeech.server.utils.audio_process import save_audio
27 28 29 30 31
from paddlespeech.server.utils.util import wav2base64


class TextHttpHandler:
    def __init__(self, server_ip="127.0.0.1", port=8090):
        """HTTP client for the PaddleSpeech text (punctuation) service.

        Args:
            server_ip (str, optional): ip of the text server. Defaults to "127.0.0.1".
            port (int, optional): port of the text server. Defaults to 8090.

        If either ``server_ip`` or ``port`` is None the handler is disabled:
        ``self.url`` is None and ``run`` returns its input unchanged.
        """
        super().__init__()
        self.server_ip = server_ip
        self.port = port
        enabled = server_ip is not None and port is not None
        self.url = "".join(("http://", self.server_ip, ":", str(self.port),
                            "/paddlespeech/text")) if enabled else None
        logger.info(f"endpoint: {self.url}")

    def run(self, text):
        """Ask the text server to punctuate ``text``.

        Args:
            text (str): the raw text to punctuate.

        Returns:
            str: the server's ``result.punc_text``; ``text`` itself when the
            handler is disabled or when the request/parsing fails.
        """
        if any(value is None for value in (self.server_ip, self.port)):
            return text

        try:
            payload = json.dumps({"text": text})
            reply = requests.post(url=self.url, data=payload).json()
            return reply["result"]["punc_text"]
        except Exception as err:
            # Best effort: fall back to the unpunctuated text on any failure.
            logger.error(f"Call punctuation {self.url} occurs error")
            logger.error(err)
            return text
72 73


74
class ASRWsAudioHandler:
    # Length of every audio chunk streamed to the server, in milliseconds.
    _CHUNK_MS = 85

    def __init__(self,
                 url=None,
                 port=None,
                 endpoint="/paddlespeech/asr/streaming",
                 punc_server_ip=None,
                 punc_server_port=None):
        """PaddleSpeech online ASR server client audio handler.

        The online ASR server uses the websocket protocol.

        Args:
            url (str, optional): the server ip. Defaults to None.
            port (int, optional): the server port. Defaults to None.
            endpoint (str, optional): websocket path; configurable to stay
                compatible with both the python and the c++ server.
            punc_server_ip (str, optional): the punctuation server ip. Defaults to None.
            punc_server_port (int, optional): the punctuation server port. Defaults to None.

        If any of ``url``, ``port`` or ``endpoint`` is None, ``self.url`` is
        None and ``run`` returns "" without connecting.
        """
        self.url = url
        self.port = port
        if url is None or port is None or endpoint is None:
            self.url = None
        else:
            self.url = "ws://" + self.url + ":" + str(self.port) + endpoint
        self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port)
        logger.info(f"endpoint: {self.url}")

    @staticmethod
    def _split_into_chunks(samples, chunk_size):
        """Zero-pad ``samples`` to a multiple of ``chunk_size`` and yield equal-size chunks."""
        # -n % k is the smallest padding that makes n + padding divisible by k.
        padding_len = -len(samples) % chunk_size
        padding = np.zeros(padding_len, dtype=samples.dtype)
        padded = np.concatenate([samples, padding], axis=0)
        for start in range(0, len(padded), chunk_size):
            yield padded[start:start + chunk_size]

    @staticmethod
    def _signal_message(signal):
        """Build the JSON control message ("start"/"end") sent to the server."""
        return json.dumps(
            {
                "name": "test.wav",
                "signal": signal,
                "nbest": 1
            },
            sort_keys=True,
            indent=4,
            separators=(',', ': '))

    def read_wave(self, wavfile_path: str):
        """Read a 16 kHz wav file and yield it as fixed-size int16 chunks.

        Args:
            wavfile_path (str): the audio wav file; its sample rate must be
                16000 Hz, which is assumed to match the model.

        Yields:
            numpy.ndarray: one 85 ms chunk of int16 PCM; the last chunk is
            zero-padded to full length.

        Raises:
            ValueError: if the file's sample rate is not 16000 Hz.
        """
        samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if sample_rate != 16000:
            raise ValueError(
                f"expected a 16000 Hz wav file, got {sample_rate} Hz: {wavfile_path}"
            )

        chunk_size = int(self._CHUNK_MS * sample_rate / 1000)
        yield from self._split_into_chunks(samples, chunk_size)

    async def run(self, wavfile_path: str):
        """Stream an audio file to the online server and return the final result.

        Args:
            wavfile_path (str): audio path

        Returns:
            dict: the server's final message, with ``result`` passed through the
            punctuation handler when it is non-empty; "" when no ASR server is
            configured.
        """
        logger.info("send a message to the server")

        if self.url is None:
            logger.error("No asr server, please input valid ip and port")
            return ""

        # 1. send websocket handshake protocol
        start_time = time.time()
        async with websockets.connect(self.url) as ws:
            # 2. the handshake is done; announce the start of the audio
            await ws.send(self._signal_message("start"))
            msg = await ws.recv()
            logger.info("client receive msg={}".format(msg))

            # 3. send chunk audio data to the engine, logging partial results
            for chunk_data in self.read_wave(wavfile_path):
                await ws.send(chunk_data.tobytes())
                msg = json.loads(await ws.recv())
                if self.punc_server and len(msg["result"]) > 0:
                    msg["result"] = self.punc_server.run(msg["result"])
                logger.info("client receive msg={}".format(msg))

            # 4. we must send the finished signal to the server
            await ws.send(self._signal_message("end"))

            # 5. decode the final message
            msg = json.loads(await ws.recv())
            if self.punc_server and len(msg["result"]) > 0:
                msg["result"] = self.punc_server.run(msg["result"])

            # 6. log the final result and compute the statistics
            elapsed_time = time.time() - start_time
            duration = soundfile.info(wavfile_path).duration
            logger.info("client final receive msg={}".format(msg))
            # Guard against empty/zero-length audio, which would divide by zero.
            rtf = elapsed_time / duration if duration > 0 else float("nan")
            logger.info(
                f"audio duration: {duration}, elapsed time: {elapsed_time}, RTF={rtf}"
            )

            return msg
L
lym0302 已提交
205 206


207
class ASRHttpHandler:
    def __init__(self, server_ip=None, port=None):
        """The ASR client http request.

        Args:
            server_ip (str, optional): the http asr server ip. Defaults to None.
            port (int, optional): the http asr server port. Defaults to None.

        If either ``server_ip`` or ``port`` is None, ``self.url`` is None and
        ``run`` returns "" without sending a request.
        """
        super().__init__()
        self.server_ip = server_ip
        self.port = port
        if server_ip is None or port is None:
            self.url = None
        else:
            self.url = 'http://' + self.server_ip + ":" + str(
                self.port) + '/paddlespeech/asr'
        logger.info(f"endpoint: {self.url}")

    def run(self, input, audio_format, sample_rate, lang):
        """Call the http asr server to recognize an audio file.

        Args:
            input (str): the audio file path; it is base64-encoded before sending
            audio_format (str): the audio format
            sample_rate: the audio sample rate, forwarded to the server as-is
            lang (str): the audio language type

        Returns:
            The parsed JSON response (``res.json()``) from the server, or ""
            when no ASR server is configured.
        """
        if self.url is None:
            logger.error("No asr server, please input valid ip and port")
            return ""

        audio = wav2base64(input)
        data = {
            "audio": audio,
            "audio_format": audio_format,
            "sample_rate": sample_rate,
            "lang": lang,
        }

        res = requests.post(url=self.url, data=json.dumps(data))

        return res.json()


L
lym0302 已提交
255 256 257 258 259 260 261 262 263 264 265
class TTSWsHandler:
    # PCM format of the synthesized audio, as opened for playback below:
    # 2-byte (16-bit) samples, mono, 24 kHz.
    _SAMPLE_RATE = 24000
    _SAMPLE_WIDTH = 2

    def __init__(self, server="127.0.0.1", port=8092, play: bool=False):
        """PaddleSpeech online TTS server client audio handler.

        The online TTS server uses the websocket protocol.

        Args:
            server (str, optional): the server ip. Defaults to "127.0.0.1".
            port (int, optional): the server port. Defaults to 8092.
            play (bool, optional): whether to play audio while it streams in
                (requires pyaudio). Defaults to False.
        """
        self.server = server
        self.port = port
        self.url = "ws://" + self.server + ":" + str(
            self.port) + "/paddlespeech/tts/streaming"
        self.play = play
        if self.play:
            import pyaudio
            self.buffer = b''
            self.p = pyaudio.PyAudio()
            self.stream = self.p.open(
                format=self.p.get_format_from_width(self._SAMPLE_WIDTH),
                channels=1,
                rate=self._SAMPLE_RATE,
                output=True)
            self.mutex = threading.Lock()
            self.start_play = True
            self.t = threading.Thread(target=self.play_audio)
            # Consecutive empty 50 ms polls tolerated before playback stops.
            self.max_fail = 50
        logger.info(f"endpoint: {self.url}")

    def play_audio(self):
        """Player-thread loop: write buffered PCM to the output stream.

        Exits once the buffer has been empty for more than ``self.max_fail``
        consecutive 50 ms polls. The counter is reset whenever audio is
        played, so slow network gaps do not accumulate and cut playback short.
        """
        fails_left = self.max_fail
        while True:
            # Take the pending bytes under the lock, but write outside it so the
            # producer is not blocked while the audio device drains.
            with self.mutex:
                pending, self.buffer = self.buffer, b''
            if not pending:
                fails_left -= 1
                time.sleep(0.05)
                if fails_left < 0:
                    break
                continue
            fails_left = self.max_fail
            self.stream.write(pending)

    async def run(self, text: str, output: str=None):
        """Send a text to the online server and collect the streamed audio.

        Args:
            text (str): sentence to be synthesized
            output (str, optional): save audio path; audio is not saved if None

        Returns:
            tuple: (first_response, final_response, duration, save_audio_success,
            receive_time_list, chunk_duration_list) where
            first_response is the seconds until the first server message,
            final_response is the seconds until the stream ended,
            duration is the seconds of audio received,
            save_audio_success is the result of ``save_audio``; it is False when
            ``output`` is None or the stream did not end with status 2,
            receive_time_list holds the ``time.time()`` stamps of the audio chunks,
            and chunk_duration_list holds the seconds of audio in each chunk.
        """
        bytes_per_second = float(self._SAMPLE_WIDTH) * self._SAMPLE_RATE
        all_bytes = b''
        receive_time_list = []
        chunk_duration_list = []
        save_audio_success = False

        # 1. Send websocket handshake request
        async with websockets.connect(self.url) as ws:
            # 2. Handshake done; send the start request and keep the session id
            start_request = json.dumps({"task": "tts", "signal": "start"})
            await ws.send(start_request)
            msg = await ws.recv()
            logger.info(f"client receive msg={msg}")
            msg = json.loads(msg)
            session = msg["session"]

            # 3. Send the speech synthesis request (UTF-8 text, base64-encoded)
            text_base64 = str(base64.b64encode(text.encode('utf-8')), "UTF8")
            request = json.dumps({"text": text_base64})
            st = time.time()
            await ws.send(request)
            logger.info("send a message to the server")

            # 4. Process responses. Status 1 carries audio, status 2 is the
            #    normal final packet without audio, and status -1 means the
            #    server hit an exception.
            message = await ws.recv()
            first_response = time.time() - st
            message = json.loads(message)
            status = message["status"]

            while status == 1:
                receive_time_list.append(time.time())
                audio = base64.b64decode(message["audio"])  # bytes
                chunk_duration_list.append(len(audio) / bytes_per_second)
                all_bytes += audio
                if self.play:
                    with self.mutex:
                        self.buffer += audio
                    if self.start_play:
                        self.t.start()
                        self.start_play = False

                message = json.loads(await ws.recv())
                status = message["status"]

            final_response = time.time() - st
            duration = len(all_bytes) / bytes_per_second

            if status == 2:
                if output is not None:
                    save_audio_success = save_audio(all_bytes, output)
            elif status != -1:
                logger.error("infer error, return status is invalid.")

            # Every terminal status closes the session explicitly.
            end_request = json.dumps({
                "task": "tts",
                "signal": "end",
                "session": session
            })
            await ws.send(end_request)

            if self.play:
                # The player thread is only started once audio has arrived;
                # joining a never-started thread raises RuntimeError.
                if not self.start_play:
                    self.t.join()
                self.stream.stop_stream()
                self.stream.close()
                self.p.terminate()

        return first_response, final_response, duration, save_audio_success, receive_time_list, chunk_duration_list

L
lym0302 已提交
390 391 392 393 394 395 396 397 398 399 400 401 402

class TTSHttpHandler:
    # PCM format of the synthesized audio, as opened for playback below:
    # 2-byte (16-bit) samples, mono, 24 kHz.
    _SAMPLE_RATE = 24000
    _SAMPLE_WIDTH = 2

    def __init__(self, server="127.0.0.1", port=8092, play: bool=False):
        """PaddleSpeech streaming TTS server client audio handler over HTTP.

        The server streams the audio back in the HTTP response body.

        Args:
            server (str, optional): the server ip. Defaults to "127.0.0.1".
            port (int, optional): the server port. Defaults to 8092.
            play (bool, optional): whether to play audio while it streams in
                (requires pyaudio). Defaults to False.
        """
        self.server = server
        self.port = port
        self.url = "http://" + str(self.server) + ":" + str(
            self.port) + "/paddlespeech/tts/streaming"
        self.play = play

        if self.play:
            import pyaudio
            self.buffer = b''
            self.p = pyaudio.PyAudio()
            self.stream = self.p.open(
                format=self.p.get_format_from_width(self._SAMPLE_WIDTH),
                channels=1,
                rate=self._SAMPLE_RATE,
                output=True)
            self.mutex = threading.Lock()
            self.start_play = True
            self.t = threading.Thread(target=self.play_audio)
            # Consecutive empty 50 ms polls tolerated before playback stops.
            self.max_fail = 50
        logger.info(f"endpoint: {self.url}")

    def play_audio(self):
        """Player-thread loop: write buffered PCM to the output stream.

        Exits once the buffer has been empty for more than ``self.max_fail``
        consecutive 50 ms polls. The counter is reset whenever audio is
        played, so slow network gaps do not accumulate and cut playback short.
        """
        fails_left = self.max_fail
        while True:
            # Take the pending bytes under the lock, but write outside it so the
            # producer is not blocked while the audio device drains.
            with self.mutex:
                pending, self.buffer = self.buffer, b''
            if not pending:
                fails_left -= 1
                time.sleep(0.05)
                if fails_left < 0:
                    break
                continue
            fails_left = self.max_fail
            self.stream.write(pending)

    def run(self,
            text: str,
            spk_id=0,
            speed=1.0,
            volume=1.0,
            sample_rate=0,
            output: str=None):
        """Send a text to the streaming tts server and collect the audio.

        Args:
            text (str): sentence to be synthesized.
            spk_id (int, optional): speaker id. Defaults to 0.
            speed (float, optional): audio speed. Defaults to 1.0.
            volume (float, optional): audio volume. Defaults to 1.0.
            sample_rate (int, optional): audio sample rate, 0 means the same as model. Defaults to 0.
            output (str, optional): save audio path. Defaults to None.

        Returns:
            tuple: (first_response, final_response, duration, save_audio_success,
            receive_time_list, chunk_duration_list) where
            first_response is the seconds until the first chunk, or
            final_response if no chunk arrived,
            final_response is the seconds until the stream ended,
            duration is the seconds of audio received,
            save_audio_success is the result of ``save_audio``, or False when
            ``output`` is None,
            receive_time_list holds the ``time.time()`` stamps of the chunks,
            and chunk_duration_list holds the seconds of audio in each chunk.
        """
        # 1. Create request
        params = {
            "text": text,
            "spk_id": spk_id,
            "speed": speed,
            "volume": volume,
            "sample_rate": sample_rate,
            "save_path": output
        }

        bytes_per_second = float(self._SAMPLE_WIDTH) * self._SAMPLE_RATE
        all_bytes = b''
        first_response = None
        receive_time_list = []
        chunk_duration_list = []

        # 2. Send request
        st = time.time()
        html = requests.post(self.url, data=json.dumps(params), stream=True)

        # 3. Process the received response; each chunk is base64-encoded PCM
        try:
            for chunk in html.iter_content(chunk_size=None):
                receive_time_list.append(time.time())
                audio = base64.b64decode(chunk)  # bytes
                if first_response is None:
                    first_response = time.time() - st

                if self.play:
                    with self.mutex:
                        self.buffer += audio
                    if self.start_play:
                        self.t.start()
                        self.start_play = False
                all_bytes += audio
                chunk_duration_list.append(len(audio) / bytes_per_second)

            final_response = time.time() - st
        finally:
            # With stream=True the connection is only released on close().
            html.close()

        if first_response is None:
            # No chunk arrived: the first and final responses coincide.
            first_response = final_response
        duration = len(all_bytes) / bytes_per_second

        if output is not None:
            save_audio_success = save_audio(all_bytes, output)
        else:
            save_audio_success = False

        if self.play:
            # The player thread is only started once audio has arrived;
            # joining a never-started thread raises RuntimeError.
            if not self.start_play:
                self.t.join()
            self.stream.stop_stream()
            self.stream.close()
            self.p.terminate()

        return first_response, final_response, duration, save_audio_success, receive_time_list, chunk_duration_list

X
xiongxinlei 已提交
504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520

class VectorHttpHandler:
    def __init__(self, server_ip=None, port=None):
        """The Vector (speaker embedding) client http request.

        Args:
            server_ip (str, optional): the http vector server ip. Defaults to None.
            port (int, optional): the http vector server port. Defaults to None.

        If either ``server_ip`` or ``port`` is None, ``self.url`` is None and
        ``run`` returns "" without sending a request.
        """
        super().__init__()
        self.server_ip = server_ip
        self.port = port
        if server_ip is None or port is None:
            self.url = None
        else:
            self.url = 'http://' + self.server_ip + ":" + str(
                self.port) + '/paddlespeech/vector'
        logger.info(f"endpoint: {self.url}")

    def run(self, input, audio_format, sample_rate, task="spk"):
        """Call the http vector server to extract an audio vector.

        Args:
            input (str): the audio file path; it is base64-encoded before sending
            audio_format (str): the audio format
            sample_rate: the audio sample rate, forwarded to the server as-is
            task (str, optional): the task name sent to the server. Defaults to "spk".

        Returns:
            The parsed JSON response (``res.json()``) from the server, or ""
            when no vector server is configured.
        """
        if self.url is None:
            logger.error("No vector server, please input valid ip and port")
            return ""

        audio = wav2base64(input)
        data = {
            "audio": audio,
            "task": task,
            "audio_format": audio_format,
            "sample_rate": sample_rate,
        }

        logger.info(self.url)
        res = requests.post(url=self.url, data=json.dumps(data))

        return res.json()


class VectorScoreHttpHandler:
    def __init__(self, server_ip=None, port=None):
        """The Vector score client http request.

        Args:
            server_ip (str, optional): the http vector server ip. Defaults to None.
            port (int, optional): the http vector server port. Defaults to None.

        If either ``server_ip`` or ``port`` is None, ``self.url`` is None and
        ``run`` returns "" without sending a request.
        """
        super().__init__()
        self.server_ip = server_ip
        self.port = port
        if server_ip is None or port is None:
            self.url = None
        else:
            self.url = 'http://' + self.server_ip + ":" + str(
                self.port) + '/paddlespeech/vector/score'
        logger.info(f"endpoint: {self.url}")

    def run(self, enroll_audio, test_audio, audio_format, sample_rate):
        """Ask the http vector server to score a test audio against an enrollment audio.

        Args:
            enroll_audio (str): the enrollment audio file path; base64-encoded before sending
            test_audio (str): the test audio file path; base64-encoded before sending
            audio_format (str): the audio format
            sample_rate: the audio sample rate, forwarded to the server as-is

        Returns:
            The parsed JSON response (``res.json()``) of the "score" task, or ""
            when no vector server is configured.
        """
        if self.url is None:
            logger.error("No vector server, please input valid ip and port")
            return ""

        enroll_audio = wav2base64(enroll_audio)
        test_audio = wav2base64(test_audio)
        data = {
            "enroll_audio": enroll_audio,
            "test_audio": test_audio,
            "task": "score",
            "audio_format": audio_format,
            "sample_rate": sample_rate,
        }

        res = requests.post(url=self.url, data=json.dumps(data))

        return res.json()