Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
05288fe1
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
05288fe1
编写于
2月 21, 2022
作者:
K
KP
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update batch input and stdin input.
上级
1818b058
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
95 addition
and
49 deletion
+95
-49
paddlespeech/cli/cls/infer.py
paddlespeech/cli/cls/infer.py
+24
-39
paddlespeech/cli/executor.py
paddlespeech/cli/executor.py
+71
-10
未找到文件。
paddlespeech/cli/cls/infer.py
浏览文件 @
05288fe1
...
...
@@ -14,7 +14,7 @@
import
argparse
import
ast
import
os
import
sys
from
collections
import
OrderedDict
from
typing
import
List
from
typing
import
Optional
from
typing
import
Union
...
...
@@ -79,7 +79,7 @@ class CLSExecutor(BaseExecutor):
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech.cls'
,
add_help
=
True
)
self
.
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
Tru
e
,
help
=
'Audio file to classify.'
)
'--input'
,
type
=
str
,
default
=
Non
e
,
help
=
'Audio file to classify.'
)
self
.
parser
.
add_argument
(
'--model'
,
type
=
str
,
...
...
@@ -221,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret
=
''
for
idx
in
topk_idx
:
label
,
score
=
self
.
_label_list
[
idx
],
result
[
idx
]
ret
+=
f
'
{
label
}
{
score
}
\n
'
ret
+=
f
'
{
label
}
{
score
}
'
return
ret
def
postprocess
(
self
,
topk
:
int
)
->
Union
[
str
,
os
.
PathLike
]:
...
...
@@ -241,36 +241,34 @@ class CLSExecutor(BaseExecutor):
label_file
=
parser_args
.
label_file
cfg_path
=
parser_args
.
config
ckpt_path
=
parser_args
.
ckpt_path
input_file
=
parser_args
.
input
topk
=
parser_args
.
topk
device
=
parser_args
.
device
job_dump_result
=
parser_args
.
job_dump_result
try
:
if
job_dump_result
:
assert
self
.
_is_job_input
(
input_file
),
'Input file should be a job file(*.job) when `job_dump_result` is True.'
job_output_file
=
os
.
path
.
abspath
(
input_file
)
+
'.done'
sys
.
stdout
=
open
(
job_output_file
,
'w'
)
task_source
=
self
.
get_task_source
(
parser_args
.
input
)
task_results
=
OrderedDict
()
has_exceptions
=
False
print
(
self
(
input_file
,
model_type
,
cfg_path
,
ckpt_path
,
label_file
,
topk
,
device
))
for
id_
,
input_
in
task_source
.
items
():
try
:
res
=
self
(
input_
,
model_type
,
cfg_path
,
ckpt_path
,
label_file
,
topk
,
device
)
task_results
[
id_
]
=
res
except
Exception
as
e
:
has_exceptions
=
True
task_results
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
if
job_dump_result
:
logger
.
info
(
f
'Results had been saved to:
{
job_output_file
}
'
)
self
.
process_task_results
(
parser_args
.
input
,
task_results
,
job_dump_result
)
return
True
except
Exception
as
e
:
logger
.
exception
(
e
)
if
has_exceptions
:
return
False
finally
:
sys
.
stdout
.
close
()
else
:
return
True
@
stats_wrapper
def
__call__
(
self
,
input
_file
:
os
.
PathLike
,
audio
_file
:
os
.
PathLike
,
model
:
str
=
'panns_cnn14'
,
config
:
Optional
[
os
.
PathLike
]
=
None
,
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
,
...
...
@@ -280,24 +278,11 @@ class CLSExecutor(BaseExecutor):
"""
Python API to call an executor.
"""
input_file
=
os
.
path
.
abspath
(
input_file
)
audio_file
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
audio_file
)
)
paddle
.
set_device
(
device
)
self
.
_init_from_path
(
model
,
config
,
ckpt_path
,
label_file
)
if
self
.
_is_job_input
(
input_file
):
# *.job
job_outputs
=
{}
job_contents
=
self
.
_job_preprocess
(
input_file
)
for
id_
,
file
in
job_contents
.
items
():
try
:
self
.
preprocess
(
file
)
self
.
infer
()
job_outputs
[
id_
]
=
self
.
postprocess
(
topk
).
strip
()
except
Exception
as
e
:
job_outputs
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
res
=
self
.
_job_postprecess
(
job_outputs
)
else
:
self
.
preprocess
(
input_file
)
self
.
infer
()
res
=
self
.
postprocess
(
topk
)
# Retrieve result of cls.
self
.
preprocess
(
audio_file
)
self
.
infer
()
res
=
self
.
postprocess
(
topk
)
# Retrieve result of cls.
return
res
paddlespeech/cli/executor.py
浏览文件 @
05288fe1
...
...
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
from
abc
import
ABC
from
abc
import
abstractmethod
from
collections
import
OrderedDict
from
typing
import
Any
from
typing
import
Dict
from
typing
import
List
...
...
@@ -21,6 +23,8 @@ from typing import Union
import
paddle
from
.log
import
logger
class
BaseExecutor
(
ABC
):
"""
...
...
@@ -28,8 +32,8 @@ class BaseExecutor(ABC):
"""
def
__init__
(
self
):
self
.
_inputs
=
d
ict
()
self
.
_outputs
=
d
ict
()
self
.
_inputs
=
OrderedD
ict
()
self
.
_outputs
=
OrderedD
ict
()
@
abstractmethod
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
...
...
@@ -102,6 +106,61 @@ class BaseExecutor(ABC):
"""
pass
def
get_task_source
(
self
,
input_
:
Union
[
str
,
os
.
PathLike
,
None
]
)
->
Dict
[
str
,
Union
[
str
,
os
.
PathLike
]]:
"""
Get task input source from command line input.
Args:
input_ (Union[str, os.PathLike, None]): Input from command line.
Returns:
Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs.
"""
if
self
.
_is_job_input
(
input_
):
ret
=
self
.
_get_job_contents
(
input_
)
else
:
ret
=
OrderedDict
()
if
input_
is
None
:
# Take input from stdin
for
i
,
line
in
enumerate
(
sys
.
stdin
):
line
=
line
.
strip
()
if
len
(
line
.
split
(
' '
))
==
1
:
ret
[
str
(
i
+
1
)]
=
line
elif
len
(
line
.
split
(
' '
))
==
2
:
id_
,
info
=
line
.
split
(
' '
)
ret
[
id_
]
=
info
else
:
# No valid input info from one line.
continue
else
:
ret
[
1
]
=
input_
return
ret
def
process_task_results
(
self
,
input_
:
Union
[
str
,
os
.
PathLike
,
None
],
results
:
Dict
[
str
,
os
.
PathLike
],
job_dump_result
:
bool
=
False
):
"""
Handling task results and redirect stdout if needed.
Args:
input_ (Union[str, os.PathLike, None]): Input from command line.
results (Dict[str, os.PathLike]): Task outputs.
job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
"""
raw_text
=
self
.
_format_task_results
(
results
)
print
(
raw_text
,
end
=
''
)
if
self
.
_is_job_input
(
input_
)
and
job_dump_result
:
try
:
job_output_file
=
os
.
path
.
abspath
(
input_
)
+
'.done'
sys
.
stdout
=
open
(
job_output_file
,
'w'
)
print
(
raw_text
,
end
=
''
)
logger
.
info
(
f
'Results had been saved to:
{
job_output_file
}
'
)
finally
:
sys
.
stdout
.
close
()
def
_is_job_input
(
self
,
input_
:
Union
[
str
,
os
.
PathLike
])
->
bool
:
"""
Check if current input file is a job input or not.
...
...
@@ -112,9 +171,10 @@ class BaseExecutor(ABC):
Returns:
bool: return `True` for job input, `False` otherwise.
"""
return
os
.
path
.
isfile
(
input_
)
and
input_
.
endswith
(
'.job'
)
return
input_
and
os
.
path
.
isfile
(
input_
)
and
input_
.
endswith
(
'.job'
)
def
_job_preprocess
(
self
,
job_input
:
os
.
PathLike
)
->
Dict
[
str
,
str
]:
def
_get_job_contents
(
self
,
job_input
:
os
.
PathLike
)
->
Dict
[
str
,
Union
[
str
,
os
.
PathLike
]]:
"""
Read a job input file and return its contents in a dictionary.
...
...
@@ -124,7 +184,7 @@ class BaseExecutor(ABC):
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents
=
{}
job_contents
=
OrderedDict
()
with
open
(
job_input
)
as
f
:
for
line
in
f
:
line
=
line
.
strip
()
...
...
@@ -134,17 +194,18 @@ class BaseExecutor(ABC):
job_contents
[
k
]
=
v
return
job_contents
def
_job_postprecess
(
self
,
job_outputs
:
Dict
[
str
,
str
])
->
str
:
def
_format_task_results
(
self
,
results
:
Dict
[
str
,
Union
[
str
,
os
.
PathLike
]])
->
str
:
"""
Convert
job results to string
.
Convert
task results to raw text
.
Args:
job_outputs (Dict[str, str]): A dictionary with job ids and
results.
results (Dict[str, str]): A dictionary of task
results.
Returns:
str: A string object contains
job outpu
ts.
str: A string object contains
task resul
ts.
"""
ret
=
''
for
k
,
v
in
job_outpu
ts
.
items
():
for
k
,
v
in
resul
ts
.
items
():
ret
+=
f
'
{
k
}
{
v
}
\n
'
return
ret
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录