Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
3151637a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3151637a
编写于
2月 21, 2022
作者:
H
Hui Zhang
提交者:
GitHub
2月 21, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1460 from KPatr1ck/cli_batch
[CLI][Batch]Support batch input in cli.
上级
a8c3f6d4
7814fba0
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
277 addition
and
68 deletion
+277
-68
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+28
-9
paddlespeech/cli/cls/infer.py
paddlespeech/cli/cls/infer.py
+30
-11
paddlespeech/cli/executor.py
paddlespeech/cli/executor.py
+111
-2
paddlespeech/cli/st/infer.py
paddlespeech/cli/st/infer.py
+28
-9
paddlespeech/cli/text/infer.py
paddlespeech/cli/text/infer.py
+28
-9
paddlespeech/cli/tts/infer.py
paddlespeech/cli/tts/infer.py
+52
-28
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
3151637a
...
@@ -12,8 +12,10 @@
...
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
ast
import
os
import
os
import
sys
import
sys
from
collections
import
OrderedDict
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
from
typing
import
Union
from
typing
import
Union
...
@@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
self
.
parser
=
argparse
.
ArgumentParser
(
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech.asr'
,
add_help
=
True
)
prog
=
'paddlespeech.asr'
,
add_help
=
True
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
Tru
e
,
help
=
'Audio file to recognize.'
)
'--input'
,
type
=
str
,
default
=
Non
e
,
help
=
'Audio file to recognize.'
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--model'
,
'--model'
,
type
=
str
,
type
=
str
,
...
@@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
...
@@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
type
=
str
,
type
=
str
,
default
=
paddle
.
get_device
(),
default
=
paddle
.
get_device
(),
help
=
'Choose device to execute model inference.'
)
help
=
'Choose device to execute model inference.'
)
self
.
parser
.
add_argument
(
'--job_dump_result'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
'Save job result into file.'
)
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
"""
"""
...
@@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
...
@@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
sample_rate
=
parser_args
.
sample_rate
sample_rate
=
parser_args
.
sample_rate
config
=
parser_args
.
config
config
=
parser_args
.
config
ckpt_path
=
parser_args
.
ckpt_path
ckpt_path
=
parser_args
.
ckpt_path
audio_file
=
parser_args
.
input
decode_method
=
parser_args
.
decode_method
decode_method
=
parser_args
.
decode_method
force_yes
=
parser_args
.
yes
force_yes
=
parser_args
.
yes
device
=
parser_args
.
device
device
=
parser_args
.
device
job_dump_result
=
parser_args
.
job_dump_result
try
:
task_source
=
self
.
get_task_source
(
parser_args
.
input
)
res
=
self
(
audio_file
,
model
,
lang
,
sample_rate
,
config
,
ckpt_path
,
task_results
=
OrderedDict
()
decode_method
,
force_yes
,
device
)
has_exceptions
=
False
logger
.
info
(
'ASR Result: {}'
.
format
(
res
))
return
True
for
id_
,
input_
in
task_source
.
items
():
except
Exception
as
e
:
try
:
logger
.
exception
(
e
)
res
=
self
(
input_
,
model
,
lang
,
sample_rate
,
config
,
ckpt_path
,
decode_method
,
force_yes
,
device
)
task_results
[
id_
]
=
res
except
Exception
as
e
:
has_exceptions
=
True
task_results
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
self
.
process_task_results
(
parser_args
.
input
,
task_results
,
job_dump_result
)
if
has_exceptions
:
return
False
return
False
else
:
return
True
@
stats_wrapper
@
stats_wrapper
def
__call__
(
self
,
def
__call__
(
self
,
...
...
paddlespeech/cli/cls/infer.py
浏览文件 @
3151637a
...
@@ -12,7 +12,9 @@
...
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
ast
import
os
import
os
from
collections
import
OrderedDict
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
from
typing
import
Union
from
typing
import
Union
...
@@ -77,7 +79,7 @@ class CLSExecutor(BaseExecutor):
...
@@ -77,7 +79,7 @@ class CLSExecutor(BaseExecutor):
self
.
parser
=
argparse
.
ArgumentParser
(
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech.cls'
,
add_help
=
True
)
prog
=
'paddlespeech.cls'
,
add_help
=
True
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
Tru
e
,
help
=
'Audio file to classify.'
)
'--input'
,
type
=
str
,
default
=
Non
e
,
help
=
'Audio file to classify.'
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--model'
,
'--model'
,
type
=
str
,
type
=
str
,
...
@@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
...
@@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
type
=
str
,
type
=
str
,
default
=
paddle
.
get_device
(),
default
=
paddle
.
get_device
(),
help
=
'Choose device to execute model inference.'
)
help
=
'Choose device to execute model inference.'
)
self
.
parser
.
add_argument
(
'--job_dump_result'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
'Save job result into file.'
)
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
"""
"""
...
@@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
...
@@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret
=
''
ret
=
''
for
idx
in
topk_idx
:
for
idx
in
topk_idx
:
label
,
score
=
self
.
_label_list
[
idx
],
result
[
idx
]
label
,
score
=
self
.
_label_list
[
idx
],
result
[
idx
]
ret
+=
f
'
{
label
}
:
{
score
}
\n
'
ret
+=
f
'
{
label
}
{
score
}
'
return
ret
return
ret
def
postprocess
(
self
,
topk
:
int
)
->
Union
[
str
,
os
.
PathLike
]:
def
postprocess
(
self
,
topk
:
int
)
->
Union
[
str
,
os
.
PathLike
]:
...
@@ -234,18 +241,30 @@ class CLSExecutor(BaseExecutor):
...
@@ -234,18 +241,30 @@ class CLSExecutor(BaseExecutor):
label_file
=
parser_args
.
label_file
label_file
=
parser_args
.
label_file
cfg_path
=
parser_args
.
config
cfg_path
=
parser_args
.
config
ckpt_path
=
parser_args
.
ckpt_path
ckpt_path
=
parser_args
.
ckpt_path
audio_file
=
parser_args
.
input
topk
=
parser_args
.
topk
topk
=
parser_args
.
topk
device
=
parser_args
.
device
device
=
parser_args
.
device
job_dump_result
=
parser_args
.
job_dump_result
try
:
task_source
=
self
.
get_task_source
(
parser_args
.
input
)
res
=
self
(
audio_file
,
model_type
,
cfg_path
,
ckpt_path
,
label_file
,
task_results
=
OrderedDict
()
topk
,
device
)
has_exceptions
=
False
logger
.
info
(
'CLS Result:
\n
{}'
.
format
(
res
))
return
True
for
id_
,
input_
in
task_source
.
items
():
except
Exception
as
e
:
try
:
logger
.
exception
(
e
)
res
=
self
(
input_
,
model_type
,
cfg_path
,
ckpt_path
,
label_file
,
topk
,
device
)
task_results
[
id_
]
=
res
except
Exception
as
e
:
has_exceptions
=
True
task_results
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
self
.
process_task_results
(
parser_args
.
input
,
task_results
,
job_dump_result
)
if
has_exceptions
:
return
False
return
False
else
:
return
True
@
stats_wrapper
@
stats_wrapper
def
__call__
(
self
,
def
__call__
(
self
,
...
@@ -259,7 +278,7 @@ class CLSExecutor(BaseExecutor):
...
@@ -259,7 +278,7 @@ class CLSExecutor(BaseExecutor):
"""
"""
Python API to call an executor.
Python API to call an executor.
"""
"""
audio_file
=
os
.
path
.
abspath
(
audio_file
)
audio_file
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
audio_file
)
)
paddle
.
set_device
(
device
)
paddle
.
set_device
(
device
)
self
.
_init_from_path
(
model
,
config
,
ckpt_path
,
label_file
)
self
.
_init_from_path
(
model
,
config
,
ckpt_path
,
label_file
)
self
.
preprocess
(
audio_file
)
self
.
preprocess
(
audio_file
)
...
...
paddlespeech/cli/executor.py
浏览文件 @
3151637a
...
@@ -12,14 +12,19 @@
...
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
from
abc
import
ABC
from
abc
import
ABC
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
collections
import
OrderedDict
from
typing
import
Any
from
typing
import
Any
from
typing
import
Dict
from
typing
import
List
from
typing
import
List
from
typing
import
Union
from
typing
import
Union
import
paddle
import
paddle
from
.log
import
logger
class
BaseExecutor
(
ABC
):
class
BaseExecutor
(
ABC
):
"""
"""
...
@@ -27,8 +32,8 @@ class BaseExecutor(ABC):
...
@@ -27,8 +32,8 @@ class BaseExecutor(ABC):
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_inputs
=
d
ict
()
self
.
_inputs
=
OrderedD
ict
()
self
.
_outputs
=
d
ict
()
self
.
_outputs
=
OrderedD
ict
()
@
abstractmethod
@
abstractmethod
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
...
@@ -100,3 +105,107 @@ class BaseExecutor(ABC):
...
@@ -100,3 +105,107 @@ class BaseExecutor(ABC):
Python API to call an executor.
Python API to call an executor.
"""
"""
pass
pass
def
get_task_source
(
self
,
input_
:
Union
[
str
,
os
.
PathLike
,
None
]
)
->
Dict
[
str
,
Union
[
str
,
os
.
PathLike
]]:
"""
Get task input source from command line input.
Args:
input_ (Union[str, os.PathLike, None]): Input from command line.
Returns:
Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs.
"""
if
self
.
_is_job_input
(
input_
):
ret
=
self
.
_get_job_contents
(
input_
)
else
:
ret
=
OrderedDict
()
if
input_
is
None
:
# Take input from stdin
for
i
,
line
in
enumerate
(
sys
.
stdin
):
line
=
line
.
strip
()
if
len
(
line
.
split
(
' '
))
==
1
:
ret
[
str
(
i
+
1
)]
=
line
elif
len
(
line
.
split
(
' '
))
==
2
:
id_
,
info
=
line
.
split
(
' '
)
ret
[
id_
]
=
info
else
:
# No valid input info from one line.
continue
else
:
ret
[
1
]
=
input_
return
ret
def
process_task_results
(
self
,
input_
:
Union
[
str
,
os
.
PathLike
,
None
],
results
:
Dict
[
str
,
os
.
PathLike
],
job_dump_result
:
bool
=
False
):
"""
Handling task results and redirect stdout if needed.
Args:
input_ (Union[str, os.PathLike, None]): Input from command line.
results (Dict[str, os.PathLike]): Task outputs.
job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
"""
raw_text
=
self
.
_format_task_results
(
results
)
print
(
raw_text
,
end
=
''
)
if
self
.
_is_job_input
(
input_
)
and
job_dump_result
:
try
:
job_output_file
=
os
.
path
.
abspath
(
input_
)
+
'.done'
sys
.
stdout
=
open
(
job_output_file
,
'w'
)
print
(
raw_text
,
end
=
''
)
logger
.
info
(
f
'Results had been saved to:
{
job_output_file
}
'
)
finally
:
sys
.
stdout
.
close
()
def
_is_job_input
(
self
,
input_
:
Union
[
str
,
os
.
PathLike
])
->
bool
:
"""
Check if current input file is a job input or not.
Args:
input_ (Union[str, os.PathLike]): Input file of current task.
Returns:
bool: return `True` for job input, `False` otherwise.
"""
return
input_
and
os
.
path
.
isfile
(
input_
)
and
input_
.
endswith
(
'.job'
)
def
_get_job_contents
(
self
,
job_input
:
os
.
PathLike
)
->
Dict
[
str
,
Union
[
str
,
os
.
PathLike
]]:
"""
Read a job input file and return its contents in a dictionary.
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents
=
OrderedDict
()
with
open
(
job_input
)
as
f
:
for
line
in
f
:
line
=
line
.
strip
()
if
not
line
:
continue
k
,
v
=
line
.
split
(
' '
)
job_contents
[
k
]
=
v
return
job_contents
def
_format_task_results
(
self
,
results
:
Dict
[
str
,
Union
[
str
,
os
.
PathLike
]])
->
str
:
"""
Convert task results to raw text.
Args:
results (Dict[str, str]): A dictionary of task results.
Returns:
str: A string object contains task results.
"""
ret
=
''
for
k
,
v
in
results
.
items
():
ret
+=
f
'
{
k
}
{
v
}
\n
'
return
ret
paddlespeech/cli/st/infer.py
浏览文件 @
3151637a
...
@@ -12,8 +12,10 @@
...
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
ast
import
os
import
os
import
subprocess
import
subprocess
from
collections
import
OrderedDict
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
from
typing
import
Union
from
typing
import
Union
...
@@ -69,7 +71,7 @@ class STExecutor(BaseExecutor):
...
@@ -69,7 +71,7 @@ class STExecutor(BaseExecutor):
self
.
parser
=
argparse
.
ArgumentParser
(
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
"paddlespeech.st"
,
add_help
=
True
)
prog
=
"paddlespeech.st"
,
add_help
=
True
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
"--input"
,
type
=
str
,
required
=
Tru
e
,
help
=
"Audio file to translate."
)
"--input"
,
type
=
str
,
default
=
Non
e
,
help
=
"Audio file to translate."
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
"--model"
,
"--model"
,
type
=
str
,
type
=
str
,
...
@@ -107,6 +109,11 @@ class STExecutor(BaseExecutor):
...
@@ -107,6 +109,11 @@ class STExecutor(BaseExecutor):
type
=
str
,
type
=
str
,
default
=
paddle
.
get_device
(),
default
=
paddle
.
get_device
(),
help
=
"Choose device to execute model inference."
)
help
=
"Choose device to execute model inference."
)
self
.
parser
.
add_argument
(
'--job_dump_result'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
'Save job result into file.'
)
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
"""
"""
...
@@ -319,17 +326,29 @@ class STExecutor(BaseExecutor):
...
@@ -319,17 +326,29 @@ class STExecutor(BaseExecutor):
sample_rate
=
parser_args
.
sample_rate
sample_rate
=
parser_args
.
sample_rate
config
=
parser_args
.
config
config
=
parser_args
.
config
ckpt_path
=
parser_args
.
ckpt_path
ckpt_path
=
parser_args
.
ckpt_path
audio_file
=
parser_args
.
input
device
=
parser_args
.
device
device
=
parser_args
.
device
job_dump_result
=
parser_args
.
job_dump_result
try
:
task_source
=
self
.
get_task_source
(
parser_args
.
input
)
res
=
self
(
audio_file
,
model
,
src_lang
,
tgt_lang
,
sample_rate
,
task_results
=
OrderedDict
()
config
,
ckpt_path
,
device
)
has_exceptions
=
False
logger
.
info
(
"ST Result: {}"
.
format
(
res
))
return
True
for
id_
,
input_
in
task_source
.
items
():
except
Exception
as
e
:
try
:
logger
.
exception
(
e
)
res
=
self
(
input_
,
model
,
src_lang
,
tgt_lang
,
sample_rate
,
config
,
ckpt_path
,
device
)
task_results
[
id_
]
=
res
except
Exception
as
e
:
has_exceptions
=
True
task_results
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
self
.
process_task_results
(
parser_args
.
input
,
task_results
,
job_dump_result
)
if
has_exceptions
:
return
False
return
False
else
:
return
True
@
stats_wrapper
@
stats_wrapper
def
__call__
(
self
,
def
__call__
(
self
,
...
...
paddlespeech/cli/text/infer.py
浏览文件 @
3151637a
...
@@ -12,8 +12,10 @@
...
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
ast
import
os
import
os
import
re
import
re
from
collections
import
OrderedDict
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
from
typing
import
Union
from
typing
import
Union
...
@@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor):
...
@@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor):
self
.
parser
=
argparse
.
ArgumentParser
(
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech.text'
,
add_help
=
True
)
prog
=
'paddlespeech.text'
,
add_help
=
True
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
Tru
e
,
help
=
'Input text.'
)
'--input'
,
type
=
str
,
default
=
Non
e
,
help
=
'Input text.'
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--task'
,
'--task'
,
type
=
str
,
type
=
str
,
...
@@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor):
...
@@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor):
type
=
str
,
type
=
str
,
default
=
paddle
.
get_device
(),
default
=
paddle
.
get_device
(),
help
=
'Choose device to execute model inference.'
)
help
=
'Choose device to execute model inference.'
)
self
.
parser
.
add_argument
(
'--job_dump_result'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
'Save job result into file.'
)
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
"""
"""
...
@@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor):
...
@@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor):
"""
"""
parser_args
=
self
.
parser
.
parse_args
(
argv
)
parser_args
=
self
.
parser
.
parse_args
(
argv
)
text
=
parser_args
.
input
task
=
parser_args
.
task
task
=
parser_args
.
task
model_type
=
parser_args
.
model
model_type
=
parser_args
.
model
lang
=
parser_args
.
lang
lang
=
parser_args
.
lang
...
@@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor):
...
@@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor):
ckpt_path
=
parser_args
.
ckpt_path
ckpt_path
=
parser_args
.
ckpt_path
punc_vocab
=
parser_args
.
punc_vocab
punc_vocab
=
parser_args
.
punc_vocab
device
=
parser_args
.
device
device
=
parser_args
.
device
job_dump_result
=
parser_args
.
job_dump_result
try
:
task_source
=
self
.
get_task_source
(
parser_args
.
input
)
res
=
self
(
text
,
task
,
model_type
,
lang
,
cfg_path
,
ckpt_path
,
task_results
=
OrderedDict
()
punc_vocab
,
device
)
has_exceptions
=
False
logger
.
info
(
'Text Result:
\n
{}'
.
format
(
res
))
return
True
for
id_
,
input_
in
task_source
.
items
():
except
Exception
as
e
:
try
:
logger
.
exception
(
e
)
res
=
self
(
input_
,
task
,
model_type
,
lang
,
cfg_path
,
ckpt_path
,
punc_vocab
,
device
)
task_results
[
id_
]
=
res
except
Exception
as
e
:
has_exceptions
=
True
task_results
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
self
.
process_task_results
(
parser_args
.
input
,
task_results
,
job_dump_result
)
if
has_exceptions
:
return
False
return
False
else
:
return
True
@
stats_wrapper
@
stats_wrapper
def
__call__
(
def
__call__
(
...
...
paddlespeech/cli/tts/infer.py
浏览文件 @
3151637a
...
@@ -12,7 +12,9 @@
...
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
ast
import
os
import
os
from
collections
import
OrderedDict
from
typing
import
Any
from
typing
import
Any
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
...
@@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor):
...
@@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor):
self
.
parser
=
argparse
.
ArgumentParser
(
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech.tts'
,
add_help
=
True
)
prog
=
'paddlespeech.tts'
,
add_help
=
True
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
Tru
e
,
help
=
'Input text to generate.'
)
'--input'
,
type
=
str
,
default
=
Non
e
,
help
=
'Input text to generate.'
)
# acoustic model
# acoustic model
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--am'
,
'--am'
,
...
@@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor):
...
@@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor):
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'output.wav'
,
help
=
'output file name'
)
'--output'
,
type
=
str
,
default
=
'output.wav'
,
help
=
'output file name'
)
self
.
parser
.
add_argument
(
'--job_dump_result'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
'Save job result into file.'
)
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
"""
"""
...
@@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor):
...
@@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor):
args
=
self
.
parser
.
parse_args
(
argv
)
args
=
self
.
parser
.
parse_args
(
argv
)
text
=
args
.
input
am
=
args
.
am
am
=
args
.
am
am_config
=
args
.
am_config
am_config
=
args
.
am_config
am_ckpt
=
args
.
am_ckpt
am_ckpt
=
args
.
am_ckpt
...
@@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor):
...
@@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor):
voc_stat
=
args
.
voc_stat
voc_stat
=
args
.
voc_stat
lang
=
args
.
lang
lang
=
args
.
lang
device
=
args
.
device
device
=
args
.
device
output
=
args
.
output
spk_id
=
args
.
spk_id
spk_id
=
args
.
spk_id
job_dump_result
=
args
.
job_dump_result
try
:
task_source
=
self
.
get_task_source
(
args
.
input
)
res
=
self
(
task_results
=
OrderedDict
()
text
=
text
,
has_exceptions
=
False
# acoustic model related
am
=
am
,
for
id_
,
input_
in
task_source
.
items
():
am_config
=
am_config
,
if
len
(
task_source
)
>
1
:
am_ckpt
=
am_ckpt
,
assert
isinstance
(
args
.
output
,
am_stat
=
am_stat
,
str
)
and
args
.
output
.
endswith
(
'.wav'
)
phones_dict
=
phones_dict
,
output
=
args
.
output
.
replace
(
'.wav'
,
f
'_
{
id_
}
.wav'
)
tones_dict
=
tones_dict
,
else
:
speaker_dict
=
speaker_dict
,
output
=
args
.
output
spk_id
=
spk_id
,
# vocoder related
try
:
voc
=
voc
,
res
=
self
(
voc_config
=
voc_config
,
text
=
input_
,
voc_ckpt
=
voc_ckpt
,
# acoustic model related
voc_stat
=
voc_stat
,
am
=
am
,
# other
am_config
=
am_config
,
lang
=
lang
,
am_ckpt
=
am_ckpt
,
device
=
device
,
am_stat
=
am_stat
,
output
=
output
)
phones_dict
=
phones_dict
,
logger
.
info
(
'Wave file has been generated: {}'
.
format
(
res
))
tones_dict
=
tones_dict
,
return
True
speaker_dict
=
speaker_dict
,
except
Exception
as
e
:
spk_id
=
spk_id
,
logger
.
exception
(
e
)
# vocoder related
voc
=
voc
,
voc_config
=
voc_config
,
voc_ckpt
=
voc_ckpt
,
voc_stat
=
voc_stat
,
# other
lang
=
lang
,
device
=
device
,
output
=
output
)
task_results
[
id_
]
=
res
except
Exception
as
e
:
has_exceptions
=
True
task_results
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
self
.
process_task_results
(
args
.
input
,
task_results
,
job_dump_result
)
if
has_exceptions
:
return
False
return
False
else
:
return
True
@
stats_wrapper
@
stats_wrapper
def
__call__
(
self
,
def
__call__
(
self
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录