# paddlespeech_client.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import io
import json
import os
import random
import tempfile
import time
from typing import List

import numpy as np
import requests
import soundfile

from ..executor import BaseExecutor
from ..util import cli_client_register
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64

# Names exported by `from <module> import *`: the two client executors below.
__all__ = ['TTSClientExecutor', 'ASRClientExecutor']


@cli_client_register(
    name='paddlespeech_client.tts', description='visit tts service')
class TTSClientExecutor(BaseExecutor):
    """Client for the ``/paddlespeech/tts`` HTTP endpoint.

    Posts the text to be synthesized to the server, decodes the base64
    audio found under ``response["result"]["audio"]`` and stores it locally
    as a ``.wav`` or ``.pcm`` file.
    """

    def __init__(self):
        """Build the command-line parser used by :meth:`execute`."""
        super(TTSClientExecutor, self).__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_client.tts', add_help=True)
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--input',
            type=str,
            default=None,
            help='Text to be synthesized.',
            required=True)
        self.parser.add_argument(
            '--spk_id', type=int, default=0, help='Speaker id')
        self.parser.add_argument(
            '--speed',
            type=float,
            default=1.0,
            help='Audio speed, the value should be set between 0 and 3')
        self.parser.add_argument(
            '--volume',
            type=float,
            default=1.0,
            help='Audio volume, the value should be set between 0 and 3')
        self.parser.add_argument(
            '--sample_rate',
            type=int,
            default=0,
            choices=[0, 8000, 16000],
            help='Sampling rate, the default is the same as the model')
        self.parser.add_argument(
            '--output',
            type=str,
            default="./output.wav",
            help='Synthesized audio file')

    def postprocess(self, response_dict: dict, outfile: str) -> float:
        """Decode the audio in a server response and write it to ``outfile``.

        Args:
            response_dict: Parsed JSON response; the base64-encoded audio is
                read from ``response_dict["result"]["audio"]``.
            outfile: Destination path. A ``.wav`` suffix writes the samples
                directly; a ``.pcm`` suffix converts them through
                ``wav2pcm``. Any other suffix only logs an error and writes
                nothing.

        Returns:
            Duration of the decoded audio in seconds
            (number of samples divided by the sample rate).
        """
        wav_base64 = response_dict["result"]["audio"]
        audio_data_byte = base64.b64decode(wav_base64)
        # Decode the audio container straight from memory.
        samples, sample_rate = soundfile.read(
            io.BytesIO(audio_data_byte), dtype='float32')

        if outfile.endswith(".wav"):
            soundfile.write(outfile, samples, sample_rate)
        elif outfile.endswith(".pcm"):
            # wav2pcm is called with file paths, so stage the wav in a
            # private temporary directory. The context manager removes it
            # even when the conversion raises, and no shell command is
            # needed for the cleanup.
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_wav = os.path.join(temp_dir, "temp.wav")
                soundfile.write(temp_wav, samples, sample_rate)
                wav2pcm(temp_wav, outfile, data_type=np.int16)
        else:
            logger.error("The format for saving audio only supports wav or pcm")

        duration = len(samples) / sample_rate
        return duration

    def _synthesize(self, text: str, server_ip: str, port: int, spk_id: int,
                    speed: float, volume: float, sample_rate: int,
                    output: str):
        """Send one synthesis request and save the returned audio.

        Returns:
            A ``(response_dict, duration, time_consume)`` tuple:
            the parsed JSON response, the audio duration in seconds and the
            wall-clock time of the HTTP request in seconds.
        """
        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/tts'
        request = {
            "text": text,
            "spk_id": spk_id,
            "speed": speed,
            "volume": volume,
            "sample_rate": sample_rate,
            "save_path": output
        }
        # Only the HTTP round trip is timed; decoding/saving is excluded.
        st = time.time()
        response = requests.post(url, json.dumps(request))
        time_consume = time.time() - st

        response_dict = response.json()
        duration = self.postprocess(response_dict, output)
        return response_dict, duration, time_consume

    def execute(self, argv: List[str]) -> bool:
        """Command-line entry point.

        Args:
            argv: Command-line arguments, parsed with ``self.parser``.

        Returns:
            True when the request and post-processing finished without an
            exception, False otherwise.
        """
        args = self.parser.parse_args(argv)
        try:
            response_dict, duration, time_consume = self._synthesize(
                text=args.input,
                server_ip=args.server_ip,
                port=args.port,
                spk_id=args.spk_id,
                speed=args.speed,
                volume=args.volume,
                sample_rate=args.sample_rate,
                output=args.output)

            logger.info(response_dict["message"])
            logger.info("Save synthesized audio successfully on %s." %
                        (args.output))
            logger.info("Audio duration: %f s." % (duration))
            logger.info("Response time: %f s." % (time_consume))
            # An empty clip has no meaningful real-time factor.
            if duration > 0:
                logger.info("RTF: %f " % (time_consume / duration))

            return True
        except Exception as e:
            # Exception (not BaseException) so Ctrl-C and SystemExit still
            # propagate; the cause is logged instead of being discarded.
            logger.error("Failed to synthesized audio.")
            logger.error(e)
            return False

    @stats_wrapper
    def __call__(self,
                 input: str,
                 server_ip: str="127.0.0.1",
                 port: int=8090,
                 spk_id: int=0,
                 speed: float=1.0,
                 volume: float=1.0,
                 sample_rate: int=0,
                 output: str="./output.wav"):
        """
        Python API to call an executor.

        Sends ``input`` to the TTS server, saves the audio to ``output`` and
        prints the server message, duration, response time and RTF. Failures
        are printed rather than raised.
        """
        try:
            response_dict, duration, time_consume = self._synthesize(
                text=input,
                server_ip=server_ip,
                port=port,
                spk_id=spk_id,
                speed=speed,
                volume=volume,
                sample_rate=sample_rate,
                output=output)

            print(response_dict["message"])
            print("Save synthesized audio successfully on %s." % (output))
            print("Audio duration: %f s." % (duration))
            print("Response time: %f s." % (time_consume))
            # An empty clip has no meaningful real-time factor.
            if duration > 0:
                print("RTF: %f " % (time_consume / duration))
        except Exception as e:
            # Exception (not BaseException) so Ctrl-C and SystemExit still
            # propagate; the cause is printed instead of being discarded.
            print("Failed to synthesized audio.")
            print(e)


@cli_client_register(
    name='paddlespeech_client.asr', description='visit asr service')
class ASRClientExecutor(BaseExecutor):
    """Client for the ``/paddlespeech/asr`` HTTP endpoint.

    Encodes an audio file with ``wav2base64``, posts it to the server and
    reports the JSON response together with the request time.
    """

    def __init__(self):
        """Build the command-line parser used by :meth:`execute`."""
        super(ASRClientExecutor, self).__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_client.asr', add_help=True)
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--input',
            type=str,
            default=None,
            help='Audio file to be recognized',
            required=True)
        self.parser.add_argument(
            '--sample_rate', type=int, default=16000, help='audio sample rate')
        self.parser.add_argument(
            '--lang', type=str, default="zh_cn", help='language')
        self.parser.add_argument(
            '--audio_format', type=str, default="wav", help='audio format')

    def _recognize(self, audio_path: str, server_ip: str, port: int,
                   sample_rate: int, lang: str, audio_format: str):
        """Send one recognition request.

        Returns:
            A ``(result, time_cost)`` tuple: the parsed JSON response and the
            wall-clock time of the HTTP request in seconds.
        """
        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/asr'
        audio = wav2base64(audio_path)
        data = {
            "audio": audio,
            "audio_format": audio_format,
            "sample_rate": sample_rate,
            "lang": lang,
        }
        # The timer starts after the audio has been encoded, so only
        # serialization and the HTTP round trip are measured.
        time_start = time.time()
        r = requests.post(url=url, data=json.dumps(data))
        # ending Timestamp
        time_end = time.time()
        return r.json(), time_end - time_start

    def execute(self, argv: List[str]) -> bool:
        """Command-line entry point.

        Args:
            argv: Command-line arguments, parsed with ``self.parser``.

        Returns:
            True when the request finished without an exception, False
            otherwise (including when the input file cannot be encoded).
        """
        args = self.parser.parse_args(argv)
        try:
            # Encoding the input happens inside the try block so that an
            # unreadable file is reported like any other failure.
            result, time_cost = self._recognize(
                audio_path=args.input,
                server_ip=args.server_ip,
                port=args.port,
                sample_rate=args.sample_rate,
                lang=args.lang,
                audio_format=args.audio_format)
            logger.info(result)
            logger.info("time cost %f s." % (time_cost))
            return True
        except Exception as e:
            # Exception (not BaseException) so Ctrl-C and SystemExit still
            # propagate; the cause is logged instead of being discarded.
            logger.error("Failed to speech recognition.")
            logger.error(e)
            return False

    @stats_wrapper
    def __call__(self,
                 input: str,
                 server_ip: str="127.0.0.1",
                 port: int=8090,
                 sample_rate: int=16000,
                 lang: str="zh_cn",
                 audio_format: str="wav"):
        """
        Python API to call an executor.

        Sends the audio file ``input`` to the ASR server and prints the JSON
        response and the request time. Failures are printed rather than
        raised.
        """
        try:
            # Encoding the input happens inside the try block so that an
            # unreadable file is reported like any other failure.
            result, time_cost = self._recognize(
                audio_path=input,
                server_ip=server_ip,
                port=port,
                sample_rate=sample_rate,
                lang=lang,
                audio_format=audio_format)
            print(result)
            print("time cost %f s." % (time_cost))
        except Exception as e:
            # Exception (not BaseException) so Ctrl-C and SystemExit still
            # propagate; the cause is printed instead of being discarded.
            print("Failed to speech recognition.")
            print(e)