Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
b585684b
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b585684b
编写于
8月 24, 2021
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add function: test export
上级
2d3b2aed
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
447 addition
and
1 deletion
+447
-1
deepspeech/exps/deepspeech2/bin/test_export.py
deepspeech/exps/deepspeech2/bin/test_export.py
+52
-0
deepspeech/exps/deepspeech2/model.py
deepspeech/exps/deepspeech2/model.py
+331
-1
deepspeech/models/ds2_online/conv.py
deepspeech/models/ds2_online/conv.py
+2
-0
deepspeech/models/ds2_online/deepspeech2.py
deepspeech/models/ds2_online/deepspeech2.py
+18
-0
examples/aishell/s0/local/test_export.sh
examples/aishell/s0/local/test_export.sh
+39
-0
examples/aishell/s0/run.sh
examples/aishell/s0/run.sh
+5
-0
未找到文件。
deepspeech/exps/deepspeech2/bin/test_export.py
0 → 100644
浏览文件 @
b585684b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.model
import
DeepSpeech2ExportTester
as
ExportTester
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
def
main_sp
(
config
,
args
):
exp
=
ExportTester
(
config
,
args
)
exp
.
setup
()
exp
.
run_test
()
def
main
(
config
,
args
):
main_sp
(
config
,
args
)
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
.
add_argument
(
"--model_type"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
if
args
.
model_type
is
None
:
args
.
model_type
=
'offline'
print
(
"model_type:{}"
.
format
(
args
.
model_type
))
# https://yaml.org/type/float.html
config
=
get_cfg_defaults
(
args
.
model_type
)
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
if
args
.
opts
:
config
.
merge_from_list
(
args
.
opts
)
config
.
freeze
()
print
(
config
)
if
args
.
dump_config
:
with
open
(
args
.
dump_config
,
'w'
)
as
f
:
print
(
config
,
file
=
f
)
main
(
config
,
args
)
deepspeech/exps/deepspeech2/model.py
浏览文件 @
b585684b
...
@@ -20,6 +20,7 @@ from typing import Optional
...
@@ -20,6 +20,7 @@ from typing import Optional
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
paddle
import
inference
from
paddle.io
import
DataLoader
from
paddle.io
import
DataLoader
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
...
@@ -145,7 +146,7 @@ class DeepSpeech2Trainer(Trainer):
...
@@ -145,7 +146,7 @@ class DeepSpeech2Trainer(Trainer):
learning_rate
=
config
.
training
.
lr
,
learning_rate
=
config
.
training
.
lr
,
gamma
=
config
.
training
.
lr_decay
,
gamma
=
config
.
training
.
lr_decay
,
verbose
=
True
)
verbose
=
True
)
optimizer
=
paddle
.
optimizer
.
Adam
(
optimizer
=
paddle
.
optimizer
.
SGD
(
#Adam
learning_rate
=
lr_scheduler
,
learning_rate
=
lr_scheduler
,
parameters
=
model
.
parameters
(),
parameters
=
model
.
parameters
(),
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
...
@@ -395,3 +396,332 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
...
@@ -395,3 +396,332 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
self
.
output_dir
=
output_dir
self
.
output_dir
=
output_dir
class
DeepSpeech2ExportTester
(
DeepSpeech2Trainer
):
@
classmethod
def
params
(
cls
,
config
:
Optional
[
CfgNode
]
=
None
)
->
CfgNode
:
# testing config
default
=
CfgNode
(
dict
(
alpha
=
2.5
,
# Coef of LM for beam search.
beta
=
0.3
,
# Coef of WC for beam search.
cutoff_prob
=
1.0
,
# Cutoff probability for pruning.
cutoff_top_n
=
40
,
# Cutoff number for pruning.
lang_model_path
=
'models/lm/common_crawl_00.prune01111.trie.klm'
,
# Filepath for language model.
decoding_method
=
'ctc_beam_search'
,
# Decoding method. Options: ctc_beam_search, ctc_greedy
error_rate_type
=
'wer'
,
# Error rate type for evaluation. Options `wer`, 'cer'
num_proc_bsearch
=
8
,
# # of CPUs for beam search.
beam_size
=
500
,
# Beam search width.
batch_size
=
128
,
# decoding batch size
))
if
config
is
not
None
:
config
.
merge_from_other_cfg
(
default
)
return
default
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
def
ordid2token
(
self
,
texts
,
texts_len
):
""" ord() id to chr() chr """
trans
=
[]
for
text
,
n
in
zip
(
texts
,
texts_len
):
n
=
n
.
numpy
().
item
()
ids
=
text
[:
n
]
trans
.
append
(
''
.
join
([
chr
(
i
)
for
i
in
ids
]))
return
trans
def
compute_metrics
(
self
,
utts
,
audio
,
audio_len
,
texts
,
texts_len
,
fout
=
None
):
cfg
=
self
.
config
.
decoding
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_func
=
error_rate
.
char_errors
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
word_errors
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
vocab_list
=
self
.
test_loader
.
collate_fn
.
vocab_list
batch_size
=
self
.
config
.
decoding
.
batch_size
output_prob_list
=
[]
output_lens_list
=
[]
decoder_chunk_size
=
8
subsampling_rate
=
self
.
model
.
encoder
.
conv
.
subsampling_rate
receptive_field_length
=
self
.
model
.
encoder
.
conv
.
receptive_field_length
chunk_stride
=
subsampling_rate
*
decoder_chunk_size
chunk_size
=
(
decoder_chunk_size
-
1
)
*
subsampling_rate
+
receptive_field_length
x_batch
=
audio
.
numpy
()
x_len_batch
=
audio_len
.
numpy
().
astype
(
np
.
int64
)
max_len_batch
=
x_batch
.
shape
[
1
]
batch_padding_len
=
chunk_stride
-
(
max_len_batch
-
chunk_size
)
%
chunk_stride
# The length of padding for the batch
x_list
=
np
.
split
(
x_batch
,
x_batch
.
shape
[
0
],
axis
=
0
)
x_len_list
=
np
.
split
(
x_len_batch
,
x_batch
.
shape
[
0
],
axis
=
0
)
for
x
,
x_len
in
zip
(
x_list
,
x_len_list
):
assert
(
chunk_size
<=
x_len
[
0
])
eouts_chunk_list
=
[]
eouts_chunk_lens_list
=
[]
padding_len_x
=
chunk_stride
-
(
x_len
[
0
]
-
chunk_size
)
%
chunk_stride
padding
=
np
.
zeros
(
(
x
.
shape
[
0
],
padding_len_x
,
x
.
shape
[
2
]),
dtype
=
np
.
float32
)
padded_x
=
np
.
concatenate
([
x
,
padding
],
axis
=
1
)
num_chunk
=
(
x_len
[
0
]
+
padding_len_x
-
chunk_size
)
/
chunk_stride
+
1
num_chunk
=
int
(
num_chunk
)
chunk_state_h_box
=
np
.
zeros
(
(
self
.
config
.
model
.
num_rnn_layers
,
1
,
self
.
config
.
model
.
rnn_layer_size
),
dtype
=
np
.
float32
)
chunk_state_c_box
=
np
.
zeros
(
(
self
.
config
.
model
.
num_rnn_layers
,
1
,
self
.
config
.
model
.
rnn_layer_size
),
dtype
=
np
.
float32
)
input_names
=
self
.
predictor
.
get_input_names
()
audio_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
1
])
h_box_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
2
])
c_box_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
3
])
probs_chunk_list
=
[]
probs_chunk_lens_list
=
[]
for
i
in
range
(
0
,
num_chunk
):
start
=
i
*
chunk_stride
end
=
start
+
chunk_size
x_chunk
=
padded_x
[:,
start
:
end
,
:]
x_len_left
=
np
.
where
(
x_len
-
i
*
chunk_stride
<
0
,
np
.
zeros_like
(
x_len
,
dtype
=
np
.
int64
),
x_len
-
i
*
chunk_stride
)
x_chunk_len_tmp
=
np
.
ones_like
(
x_len
,
dtype
=
np
.
int64
)
*
chunk_size
x_chunk_lens
=
np
.
where
(
x_len_left
<
x_chunk_len_tmp
,
x_len_left
,
x_chunk_len_tmp
)
if
(
x_chunk_lens
[
0
]
<
receptive_field_length
):
#means the number of input frames in the chunk is not enough for predicting one prob
break
audio_handle
.
reshape
(
x_chunk
.
shape
)
audio_handle
.
copy_from_cpu
(
x_chunk
)
audio_len_handle
.
reshape
(
x_chunk_lens
.
shape
)
audio_len_handle
.
copy_from_cpu
(
x_chunk_lens
)
h_box_handle
.
reshape
(
chunk_state_h_box
.
shape
)
h_box_handle
.
copy_from_cpu
(
chunk_state_h_box
)
c_box_handle
.
reshape
(
chunk_state_c_box
.
shape
)
c_box_handle
.
copy_from_cpu
(
chunk_state_c_box
)
output_names
=
self
.
predictor
.
get_output_names
()
output_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
0
])
output_lens_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
])
output_state_h_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
])
output_state_c_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
3
])
self
.
predictor
.
run
()
output_chunk_prob
=
output_handle
.
copy_to_cpu
()
output_chunk_lens
=
output_lens_handle
.
copy_to_cpu
()
chunk_state_h_box
=
output_state_h_handle
.
copy_to_cpu
()
chunk_state_c_box
=
output_state_c_handle
.
copy_to_cpu
()
output_chunk_prob
=
paddle
.
to_tensor
(
output_chunk_prob
)
output_chunk_lens
=
paddle
.
to_tensor
(
output_chunk_lens
)
probs_chunk_list
.
append
(
output_chunk_prob
)
probs_chunk_lens_list
.
append
(
output_chunk_lens
)
output_prob
=
paddle
.
concat
(
probs_chunk_list
,
axis
=
1
)
output_lens
=
paddle
.
add_n
(
probs_chunk_lens_list
)
output_prob_padding_len
=
max_len_batch
+
batch_padding_len
-
output_prob
.
shape
[
1
]
output_prob_padding
=
paddle
.
zeros
(
(
1
,
output_prob_padding_len
,
output_prob
.
shape
[
2
]),
dtype
=
"float32"
)
# The prob padding for a piece of utterance
output_prob
=
paddle
.
concat
(
[
output_prob
,
output_prob_padding
],
axis
=
1
)
output_prob_list
.
append
(
output_prob
)
output_lens_list
.
append
(
output_lens
)
output_prob_branch
=
paddle
.
concat
(
output_prob_list
,
axis
=
0
)
output_lens_branch
=
paddle
.
concat
(
output_lens_list
,
axis
=
0
)
"""
x = audio.numpy()
x_len = audio_len.numpy().astype(np.int64)
input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0])
audio_len_handle = self.predictor.get_input_handle(input_names[1])
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])
audio_handle.reshape(x.shape)
audio_handle.copy_from_cpu(x)
audio_len_handle.reshape(x_len.shape)
audio_len_handle.copy_from_cpu(x_len)
init_state_h_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
init_state_c_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
h_box_handle.reshape(init_state_h_box.shape)
h_box_handle.copy_from_cpu(init_state_h_box)
c_box_handle.reshape(init_state_c_box.shape)
c_box_handle.copy_from_cpu(init_state_c_box)
#self.autolog.times.start()
#self.autolog.times.stamp()
self.predictor.run()
output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle(output_names[0])
output_lens_handle = self.predictor.get_output_handle(output_names[1])
output_state_h_handle = self.predictor.get_output_handle(output_names[2])
output_state_c_handle = self.predictor.get_output_handle(output_names[3])
output_prob = output_handle.copy_to_cpu()
output_lens = output_lens_handle.copy_to_cpu()
output_stata_h_box = output_state_h_handle.copy_to_cpu()
output_stata_c_box = output_state_c_handle.copy_to_cpu()
output_prob_branch = paddle.to_tensor(output_prob)
output_lens_branch = paddle.to_tensor(output_lens)
"""
result_transcripts
=
self
.
model
.
decode_by_probs
(
output_prob_branch
,
output_lens_branch
,
vocab_list
,
decoding_method
=
cfg
.
decoding_method
,
lang_model_path
=
cfg
.
lang_model_path
,
beam_alpha
=
cfg
.
alpha
,
beam_beta
=
cfg
.
beta
,
beam_size
=
cfg
.
beam_size
,
cutoff_prob
=
cfg
.
cutoff_prob
,
cutoff_top_n
=
cfg
.
cutoff_top_n
,
num_processes
=
cfg
.
num_proc_bsearch
)
#self.autolog.times.stamp()
#self.autolog.times.stamp()
#self.autolog.times.end()
target_transcripts
=
self
.
ordid2token
(
texts
,
texts_len
)
for
utt
,
target
,
result
in
zip
(
utts
,
target_transcripts
,
result_transcripts
):
errors
,
len_ref
=
errors_func
(
target
,
result
)
errors_sum
+=
errors
len_refs
+=
len_ref
num_ins
+=
1
if
fout
:
fout
.
write
(
utt
+
" "
+
result
+
"
\n
"
)
logger
.
info
(
"
\n
Target Transcription: %s
\n
Output Transcription: %s"
%
(
target
,
result
))
logger
.
info
(
"Current error rate [%s] = %f"
%
(
cfg
.
error_rate_type
,
error_rate_func
(
target
,
result
)))
return
dict
(
errors_sum
=
errors_sum
,
len_refs
=
len_refs
,
num_ins
=
num_ins
,
error_rate
=
errors_sum
/
len_refs
,
error_rate_type
=
cfg
.
error_rate_type
)
@
mp_tools
.
rank_zero_only
@
paddle
.
no_grad
()
def
test
(
self
):
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
#self.autolog = Autolog(
# batch_size=self.config.decoding.batch_size,
# model_name="deepspeech2",
# model_precision="fp32").getlog()
self
.
model
.
eval
()
cfg
=
self
.
config
error_rate_type
=
None
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
with
open
(
self
.
args
.
result_file
,
'w'
)
as
fout
:
for
i
,
batch
in
enumerate
(
self
.
test_loader
):
utts
,
audio
,
audio_len
,
texts
,
texts_len
=
batch
metrics
=
self
.
compute_metrics
(
utts
,
audio
,
audio_len
,
texts
,
texts_len
,
fout
)
errors_sum
+=
metrics
[
'errors_sum'
]
len_refs
+=
metrics
[
'len_refs'
]
num_ins
+=
metrics
[
'num_ins'
]
error_rate_type
=
metrics
[
'error_rate_type'
]
logger
.
info
(
"Error rate [%s] (%d/?) = %f"
%
(
error_rate_type
,
num_ins
,
errors_sum
/
len_refs
))
# logging
msg
=
"Test: "
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
msg
+=
"Final error rate [%s] (%d/%d) = %f"
%
(
error_rate_type
,
num_ins
,
num_ins
,
errors_sum
/
len_refs
)
logger
.
info
(
msg
)
#self.autolog.report()
def
run_test
(
self
):
try
:
self
.
test
()
except
KeyboardInterrupt
:
exit
(
-
1
)
def
run_export
(
self
):
try
:
self
.
export
()
except
KeyboardInterrupt
:
exit
(
-
1
)
def
setup
(
self
):
"""Setup the experiment.
"""
paddle
.
set_device
(
self
.
args
.
device
)
self
.
setup_output_dir
()
#self.setup_checkpointer()
self
.
setup_dataloader
()
self
.
setup_model
()
self
.
iteration
=
0
self
.
epoch
=
0
def
setup_output_dir
(
self
):
"""Create a directory used for output.
"""
# output dir
if
self
.
args
.
output
:
output_dir
=
Path
(
self
.
args
.
output
).
expanduser
()
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
else
:
output_dir
=
Path
(
self
.
args
.
export_path
).
expanduser
().
parent
.
parent
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
self
.
output_dir
=
output_dir
def
setup_model
(
self
):
super
().
setup_model
()
if
self
.
args
.
model_type
==
'online'
:
#inference_dir = "exp/deepspeech2_online/checkpoints/"
#inference_dir = "exp/deepspeech2_online_3rr_1fc_lr_decay0.91_lstm/checkpoints/"
#speedyspeech_config = inference.Config(
# str(Path(inference_dir) / "avg_1.jit.pdmodel"),
# str(Path(inference_dir) / "avg_1.jit.pdiparams"))
speedyspeech_config
=
inference
.
Config
(
self
.
args
.
export_path
+
".pdmodel"
,
self
.
args
.
export_path
+
".pdiparams"
)
speedyspeech_config
.
enable_use_gpu
(
100
,
0
)
speedyspeech_config
.
enable_memory_optim
()
speedyspeech_predictor
=
inference
.
create_predictor
(
speedyspeech_config
)
self
.
predictor
=
speedyspeech_predictor
deepspeech/models/ds2_online/conv.py
浏览文件 @
b585684b
...
@@ -30,4 +30,6 @@ class Conv2dSubsampling4Online(Conv2dSubsampling4):
...
@@ -30,4 +30,6 @@ class Conv2dSubsampling4Online(Conv2dSubsampling4):
#b, c, t, f = paddle.shape(x) #not work under jit
#b, c, t, f = paddle.shape(x) #not work under jit
x
=
x
.
transpose
([
0
,
2
,
1
,
3
]).
reshape
([
0
,
0
,
-
1
])
x
=
x
.
transpose
([
0
,
2
,
1
,
3
]).
reshape
([
0
,
0
,
-
1
])
x_len
=
((
x_len
-
1
)
//
2
-
1
)
//
2
x_len
=
((
x_len
-
1
)
//
2
-
1
)
//
2
x_len
=
paddle
.
where
(
x_len
>=
0
,
x_len
,
paddle
.
zeros_like
(
x_len
.
shape
,
"int64"
))
return
x
,
x_len
return
x
,
x_len
deepspeech/models/ds2_online/deepspeech2.py
浏览文件 @
b585684b
...
@@ -325,6 +325,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
...
@@ -325,6 +325,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
lang_model_path
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
lang_model_path
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
num_processes
)
cutoff_top_n
,
num_processes
)
@
paddle
.
no_grad
()
def
decode_by_probs
(
self
,
probs
,
probs_len
,
vocab_list
,
decoding_method
,
lang_model_path
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
num_processes
):
# init once
# decoders only accept string encoded in utf-8
self
.
decoder
.
init_decode
(
beam_alpha
=
beam_alpha
,
beam_beta
=
beam_beta
,
lang_model_path
=
lang_model_path
,
vocab_list
=
vocab_list
,
decoding_method
=
decoding_method
)
return
self
.
decoder
.
decode_probs
(
probs
.
numpy
(),
probs_len
,
vocab_list
,
decoding_method
,
lang_model_path
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
num_processes
)
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
dataloader
,
config
,
checkpoint_path
):
def
from_pretrained
(
cls
,
dataloader
,
config
,
checkpoint_path
):
"""Build a DeepSpeech2Model model from a pretrained model.
"""Build a DeepSpeech2Model model from a pretrained model.
...
...
examples/aishell/s0/local/test_export.sh
0 → 100755
浏览文件 @
b585684b
#!/bin/bash
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix model_type"
exit
-1
fi
ngpu
=
$(
echo
$CUDA_VISIBLE_DEVICES
|
awk
-F
","
'{print NF}'
)
echo
"using
$ngpu
gpus..."
device
=
gpu
if
[
${
ngpu
}
==
0
]
;
then
device
=
cpu
fi
config_path
=
$1
jit_model_export_path
=
$2
model_type
=
$3
# download language model
bash
local
/download_lm_ch.sh
if
[
$?
-ne
0
]
;
then
exit
1
fi
python3
-u
${
BIN_DIR
}
/test_export.py
\
--device
${
device
}
\
--nproc
1
\
--config
${
config_path
}
\
--result_file
${
ckpt_prefix
}
.rsl
\
--export_path
${
jit_model_export_path
}
\
--model_type
${
model_type
}
if
[
$?
-ne
0
]
;
then
echo
"Failed in evaluation!"
exit
1
fi
exit
0
examples/aishell/s0/run.sh
浏览文件 @
b585684b
...
@@ -39,3 +39,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
...
@@ -39,3 +39,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
# export ckpt avg_n
CUDA_VISIBLE_DEVICES
=
0 ./local/export.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
.jit
${
model_type
}
CUDA_VISIBLE_DEVICES
=
0 ./local/export.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
.jit
${
model_type
}
fi
fi
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
# test export ckpt avg_n
CUDA_VISIBLE_DEVICES
=
0 ./local/test_export.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
.jit
${
model_type
}
||
exit
-1
fi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录