paddlespeech_client.py 12.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
H
format  
Hui Zhang 已提交
15
import asyncio
16 17 18
import base64
import io
import json
H
format  
Hui Zhang 已提交
19
import logging
20 21 22 23 24 25 26 27 28
import os
import random
import time
from typing import List

import numpy as np
import requests
import soundfile

L
lym0302 已提交
29
from ..executor import BaseExecutor
30
from ..util import cli_client_register
L
lym0302 已提交
31 32
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
H
format  
Hui Zhang 已提交
33
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
34 35 36
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64

L
lym0302 已提交
37
# Client executors exported on wildcard import from this module.
__all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor']
38 39 40 41


@cli_client_register(
    name='paddlespeech_client.tts', description='visit tts service')
class TTSClientExecutor(BaseExecutor):
    """Client for the PaddleSpeech TTS HTTP service.

    Posts text to ``http://<server_ip>:<port>/paddlespeech/tts``. When an
    output path is given, the base64 audio in the JSON response is decoded
    and saved locally as ``.wav`` or ``.pcm``.
    """

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_client.tts', add_help=True)
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--input',
            type=str,
            default=None,
            help='Text to be synthesized.',
            required=True)
        self.parser.add_argument(
            '--spk_id', type=int, default=0, help='Speaker id')
        self.parser.add_argument(
            '--speed',
            type=float,
            default=1.0,
            help='Audio speed, the value should be set between 0 and 3')
        self.parser.add_argument(
            '--volume',
            type=float,
            default=1.0,
            help='Audio volume, the value should be set between 0 and 3')
        self.parser.add_argument(
            '--sample_rate',
            type=int,
            default=0,
            choices=[0, 8000, 16000],
            help='Sampling rate, the default is the same as the model')
        self.parser.add_argument(
            '--output', type=str, default=None, help='Synthesized audio file')

    def postprocess(self, wav_base64: str, outfile: str) -> None:
        """Decode base64-encoded audio and save it to ``outfile``.

        Args:
            wav_base64: Base64 string of an audio file that ``soundfile``
                can read.
            outfile: Destination path. It must end with ``.wav`` or ``.pcm``.
                For any other suffix an error is logged and nothing is
                written.
        """
        audio_data_byte = base64.b64decode(wav_base64)
        samples, sample_rate = soundfile.read(
            io.BytesIO(audio_data_byte), dtype='float32')

        if outfile.endswith(".wav"):
            soundfile.write(outfile, samples, sample_rate)
        elif outfile.endswith(".pcm"):
            # wav2pcm converts from a file on disk, so stage the audio in a
            # private temporary .wav that is removed even if conversion fails.
            import tempfile
            fd, temp_wav = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            try:
                soundfile.write(temp_wav, samples, sample_rate)
                wav2pcm(temp_wav, outfile, data_type=np.int16)
            finally:
                os.remove(temp_wav)
        else:
            logger.error("The format for saving audio only supports wav or pcm")

    def execute(self, argv: List[str]) -> bool:
        """Parse CLI arguments, call the TTS service and log the outcome.

        Returns:
            True on success, False if the request or result handling raised.
        """
        args = self.parser.parse_args(argv)
        output = args.output

        try:
            time_start = time.time()
            res = self(
                input=args.input,
                server_ip=args.server_ip,
                port=args.port,
                spk_id=args.spk_id,
                speed=args.speed,
                volume=args.volume,
                sample_rate=args.sample_rate,
                output=output)
            time_consume = time.time() - time_start
            response_dict = res.json()
            logger.info(response_dict["message"])
            # Only claim a local save when an output path was requested.
            if output is not None:
                logger.info("Save synthesized audio successfully on %s." %
                            (output))
            logger.info("Audio duration: %f s." %
                        (response_dict['result']['duration']))
            logger.info("Response time: %f s." % (time_consume))
            return True
        except Exception as e:
            logger.error("Failed to synthesized audio.")
            logger.error(e)
            return False

    @stats_wrapper
    def __call__(self,
                 input: str,
                 server_ip: str="127.0.0.1",
                 port: int=8090,
                 spk_id: int=0,
                 speed: float=1.0,
                 volume: float=1.0,
                 sample_rate: int=0,
                 output: str=None):
        """
        Python API to call an executor.

        Args:
            input: Text to synthesize.
            server_ip: Address of the TTS server.
            port: Port of the TTS server.
            spk_id: Speaker id.
            speed: Audio speed.
            volume: Audio volume.
            sample_rate: Sample rate; 0 keeps the model's rate.
            output: Optional local path (``.wav`` or ``.pcm``). It is also
                forwarded to the server as ``save_path``.

        Returns:
            The ``requests`` response object from the server.
        """
        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/tts'
        request = {
            "text": input,
            "spk_id": spk_id,
            "speed": speed,
            "volume": volume,
            "sample_rate": sample_rate,
            "save_path": output
        }

        res = requests.post(url, json.dumps(request))
        if output is not None:
            response_dict = res.json()
            self.postprocess(response_dict["result"]["audio"], output)
        return res
159 160 161 162


@cli_client_register(
    name='paddlespeech_client.asr', description='visit asr service')
class ASRClientExecutor(BaseExecutor):
    """Client for the PaddleSpeech ASR HTTP service.

    Posts a base64-encoded audio file to
    ``http://<server_ip>:<port>/paddlespeech/asr``.
    """

    def __init__(self):
        # Zero-argument super(): ``super(ASRClientExecutor, self)`` looks up
        # the module-level name at call time, which breaks if another class in
        # this module is later bound to the same name.
        super().__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_client.asr', add_help=True)
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--input',
            type=str,
            default=None,
            help='Audio file to be recognized',
            required=True)
        self.parser.add_argument(
            '--sample_rate', type=int, default=16000, help='audio sample rate')
        self.parser.add_argument(
            '--lang', type=str, default="zh_cn", help='language')
        self.parser.add_argument(
            '--audio_format', type=str, default="wav", help='audio format')

    def execute(self, argv: List[str]) -> bool:
        """Parse CLI arguments, call the ASR service and log the response.

        Returns:
            True on success, False if the request or result handling raised.
        """
        args = self.parser.parse_args(argv)

        try:
            time_start = time.time()
            res = self(
                input=args.input,
                server_ip=args.server_ip,
                port=args.port,
                sample_rate=args.sample_rate,
                lang=args.lang,
                audio_format=args.audio_format)
            time_end = time.time()
            logger.info(res.json())
            logger.info("Response time %f s." % (time_end - time_start))
            return True
        except Exception as e:
            logger.error("Failed to speech recognition.")
            logger.error(e)
            return False

    @stats_wrapper
    def __call__(self,
                 input: str,
                 server_ip: str="127.0.0.1",
                 port: int=8090,
                 sample_rate: int=16000,
                 lang: str="zh_cn",
                 audio_format: str="wav"):
        """
        Python API to call an executor.

        Args:
            input: Path of the audio file; it is base64-encoded via
                ``wav2base64`` before sending.
            server_ip: Address of the ASR server.
            port: Port of the ASR server.
            sample_rate: Sample rate sent with the request.
            lang: Language code sent with the request.
            audio_format: Audio format sent with the request.

        Returns:
            The ``requests`` response object from the server.
        """
        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/asr'
        data = {
            "audio": wav2base64(input),
            "audio_format": audio_format,
            "sample_rate": sample_rate,
            "lang": lang,
        }

        res = requests.post(url=url, data=json.dumps(data))
        return res
L
lym0302 已提交
234 235


236
@cli_client_register(
    name='paddlespeech_client.asr_online',
    description='visit asr online service')
class ASROnlineClientExecutor(BaseExecutor):
    """Client for the streaming (websocket) ASR service.

    Streams an audio file to the server through ``ASRAudioHandler`` and
    returns the ``asr_results`` entry of the handler's result.
    """

    # NOTE: deliberately not named ``ASRClientExecutor``. That name belongs to
    # the HTTP ASR client in this module, and reusing it shadowed that class.

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_client.asr_online', add_help=True)
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8091, help='server port')
        self.parser.add_argument(
            '--input',
            type=str,
            default=None,
            help='Audio file to be recognized',
            required=True)
        self.parser.add_argument(
            '--sample_rate', type=int, default=16000, help='audio sample rate')
        self.parser.add_argument(
            '--lang', type=str, default="zh_cn", help='language')
        self.parser.add_argument(
            '--audio_format', type=str, default="wav", help='audio format')

    def execute(self, argv: List[str]) -> bool:
        """Parse CLI arguments, run streaming recognition and log the result.

        Returns:
            True on success, False if recognition raised.
        """
        args = self.parser.parse_args(argv)

        try:
            time_start = time.time()
            res = self(
                input=args.input,
                server_ip=args.server_ip,
                port=args.port,
                sample_rate=args.sample_rate,
                lang=args.lang,
                audio_format=args.audio_format)
            time_end = time.time()
            logger.info(res)
            logger.info("Response time %f s." % (time_end - time_start))
            return True
        except Exception as e:
            logger.error("Failed to speech recognition.")
            logger.error(e)
            return False

    @stats_wrapper
    def __call__(self,
                 input: str,
                 server_ip: str="127.0.0.1",
                 port: int=8091,
                 sample_rate: int=16000,
                 lang: str="zh_cn",
                 audio_format: str="wav"):
        """
        Python API to call an executor.

        Args:
            input: Path of the audio file to stream to the server.
            server_ip: Address of the websocket ASR server.
            port: Port of the websocket ASR server.
            sample_rate: Accepted for signature parity with the HTTP client;
                not forwarded to ``ASRAudioHandler`` here.
            lang: Same as ``sample_rate``; not forwarded.
            audio_format: Same as ``sample_rate``; not forwarded.

        Returns:
            ``res['asr_results']`` from the handler's ``run`` result.
        """
        # NOTE(review): configures the root logger, presumably so the
        # websocket handler's stdlib ``logging`` output is visible; confirm.
        logging.basicConfig(level=logging.INFO)
        logger.info("asr websocket client start")
        handler = ASRAudioHandler(server_ip, port)
        # Use a private loop instead of asyncio.get_event_loop(), which is
        # deprecated when no loop is running. Close it so it does not leak.
        loop = asyncio.new_event_loop()
        try:
            res = loop.run_until_complete(handler.run(input))
        finally:
            loop.close()
        logger.info("asr websocket client finished")

        return res['asr_results']
307

L
lym0302 已提交
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
@cli_client_register(
    name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor):
    """Client for the PaddleSpeech audio-classification HTTP service.

    Posts a base64-encoded audio file to
    ``http://<server_ip>:<port>/paddlespeech/cls``.
    """

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_client.cls', add_help=True)
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--input',
            type=str,
            default=None,
            help='Audio file to classify.',
            required=True)
        self.parser.add_argument(
            '--topk',
            type=int,
            default=1,
            help='Return topk scores of classification result.')

    def execute(self, argv: List[str]) -> bool:
        """Parse CLI arguments, call the CLS service and log the response.

        Returns:
            True on success, False if the request or result handling raised.
        """
        args = self.parser.parse_args(argv)

        try:
            time_start = time.time()
            res = self(
                input=args.input,
                server_ip=args.server_ip,
                port=args.port,
                topk=args.topk)
            time_end = time.time()
            logger.info(res.json())
            logger.info("Response time %f s." % (time_end - time_start))
            return True
        except Exception as e:
            logger.error("Failed to speech classification.")
            logger.error(e)
            return False

    @stats_wrapper
    def __call__(self,
                 input: str,
                 server_ip: str="127.0.0.1",
                 port: int=8090,
                 topk: int=1):
        """
        Python API to call an executor.

        Args:
            input: Path of the audio file; it is base64-encoded via
                ``wav2base64`` before sending.
            server_ip: Address of the CLS server.
            port: Port of the CLS server.
            topk: Number of top classification scores to request.

        Returns:
            The ``requests`` response object from the server.
        """
        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls'
        data = {"audio": wav2base64(input), "topk": topk}

        res = requests.post(url=url, data=json.dumps(data))
        return res