PaddlePaddle / DeepSpeech
Commit 993d6783
Authored Mar 09, 2022 by xiongxinlei

remove unused code, test=doc

Parent: 0e87037f

Showing 7 changed files with 1 addition and 687 deletions (+1 -687)
examples/voxceleb/sv0/local/data.sh            +0   -25
paddlespeech/vector/__init__.py                +1   -29
paddlespeech/vector/datasets/ark_dataset.py    +0   -142
paddlespeech/vector/datasets/dataset.py        +0   -143
paddlespeech/vector/datasets/egs_dataset.py    +0   -91
paddlespeech/vector/utils/data_utils.py        +0   -125
paddlespeech/vector/utils/utils.py             +0   -132

examples/voxceleb/sv0/local/data.sh
Deleted (file mode 100755 → 0)

stage=-1
stop_stage=100
TARGET_DIR=${MAIN_ROOT}/dataset

. utils/parse_options.sh || exit -1;

src=$1
mkdir -p data/{dev,test}

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # download data, generate manifests
    # create data/{dev,test} directory to store the manifest files
    python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \
        --manifest_prefix="data/manifest" \
        --target_dir="${src}"

    if [ $? -ne 0 ]; then
        echo "Prepare Voxceleb failed. Terminated."
        exit 1
    fi

    mv data/manifest.dev data/dev
    mv data/voxceleb1.dev.meta data/dev
    mv data/manifest.test data/test
    mv data/voxceleb1.test.meta data/test
fi

paddlespeech/vector/__init__.py
@@ -11,31 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
\ No newline at end of file
-"""
-__init__ file for sidt package.
-"""
-import logging as sidt_logging
-import colorlog
-
-LOG_COLOR_CONFIG = {
-    'DEBUG': 'white',
-    'INFO': 'white',
-    'WARNING': 'yellow',
-    'ERROR': 'red',
-    'CRITICAL': 'purple',
-}
-# set up the global logger
-colored_formatter = colorlog.ColoredFormatter(
-    '%(log_color)s [%(levelname)s] [%(asctime)s] [%(filename)s:%(lineno)d] - %(message)s',
-    datefmt="%Y-%m-%d %H:%M:%S",
-    log_colors=LOG_COLOR_CONFIG)
-# log output format
-_logger = sidt_logging.getLogger("sidt")
-handler = colorlog.StreamHandler()
-handler.setLevel(sidt_logging.INFO)
-handler.setFormatter(colored_formatter)
-_logger.addHandler(handler)
-_logger.setLevel(sidt_logging.INFO)
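
For reference, the removed _logger was the package-level handle that the other modules deleted in this commit imported; a minimal sketch of how it was consumed (the message and path strings here are illustrative, not from the repository):

from paddlespeech.vector import _logger as log

log.info("egs dataset subset index: %d" % 3)      # rendered through the colorlog formatter
log.warning("No such file or directory: %s" % "exp/egs.1.scp")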

paddlespeech/vector/datasets/ark_dataset.py
Deleted (file mode 100755 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random

import numpy as np
import kaldi_python_io as k_io
from paddle.io import Dataset

from paddlespeech.vector.utils.data_utils import batch_pad_right
import paddlespeech.vector.utils as utils
from paddlespeech.vector.utils.utils import read_map_file
from paddlespeech.vector import _logger as log


def ark_collate_fn(batch):
    """
    Custom collate function for kaldi feats dataset

    Args:
        min_chunk_size: min chunk size of a utterance
        max_chunk_size: max chunk size of a utterance
    Returns:
        ark_collate_fn: collate function for dataloader
    """
    data = []
    target = []
    for items in batch:
        for x, y in zip(items[0], items[1]):
            data.append(np.array(x))
            target.append(y)

    data, lengths = batch_pad_right(data)
    return np.array(data, dtype=np.float32), \
           np.array(lengths, dtype=np.float32), \
           np.array(target, dtype=np.long).reshape((len(target), 1))


class KaldiArkDataset(Dataset):
    """
    Dataset used to load kaldi ark/scp files.
    """

    def __init__(self,
                 scp_file,
                 label2utt,
                 min_item_size=1,
                 max_item_size=1,
                 repeat=50,
                 min_chunk_size=200,
                 max_chunk_size=400,
                 select_by_speaker=True):
        self.scp_file = scp_file
        self.scp_reader = None
        self.repeat = repeat
        self.min_item_size = min_item_size
        self.max_item_size = max_item_size
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self._collate_fn = ark_collate_fn
        self._is_select_by_speaker = select_by_speaker
        if utils.is_exist(self.scp_file):
            self.scp_reader = k_io.ScriptReader(self.scp_file)

        label2utts, utt2label = read_map_file(label2utt, key_func=int)
        self.utt_info = list(label2utts.items()) if self._is_select_by_speaker \
            else list(utt2label.items())

    @property
    def collate_fn(self):
        """
        Return a collate function.
        """
        return self._collate_fn

    def _random_chunk(self, length):
        chunk_size = random.randint(self.min_chunk_size, self.max_chunk_size)
        if chunk_size >= length:
            return 0, length
        start = random.randint(0, length - chunk_size)
        end = start + chunk_size
        return start, end

    def _select_by_speaker(self, index):
        if self.scp_reader is None or not self.utt_info:
            return []
        index = index % (len(self.utt_info))
        inputs = []
        labels = []
        item_size = random.randint(self.min_item_size, self.max_item_size)
        for loop_idx in range(item_size):
            try:
                utt_index = random.randint(0, len(self.utt_info[index][1])) \
                    % len(self.utt_info[index][1])
                key = self.utt_info[index][1][utt_index]
            except:
                print(index, utt_index, len(self.utt_info[index][1]))
                sys.exit(-1)
            x = self.scp_reader[key]
            x = np.transpose(x)
            bg, end = self._random_chunk(x.shape[-1])
            inputs.append(x[:, bg:end])
            labels.append(self.utt_info[index][0])
        return inputs, labels

    def _select_by_utt(self, index):
        if self.scp_reader is None or len(self.utt_info) == 0:
            return {}
        index = index % (len(self.utt_info))
        key = self.utt_info[index][0]
        x = self.scp_reader[key]
        x = np.transpose(x)
        bg, end = self._random_chunk(x.shape[-1])
        y = self.utt_info[index][1]
        return [x[:, bg:end]], [y]

    def __getitem__(self, index):
        if self._is_select_by_speaker:
            return self._select_by_speaker(index)
        else:
            return self._select_by_utt(index)

    def __len__(self):
        return len(self.utt_info) * self.repeat

    def __iter__(self):
        self._start = 0
        return self

    def __next__(self):
        if self._start < len(self):
            ret = self[self._start]
            self._start += 1
            return ret
        else:
            raise StopIteration
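
For reference, a minimal sketch of how KaldiArkDataset was presumably wired into a paddle.io.DataLoader before its removal; the scp and map-file paths are hypothetical:

from paddle.io import DataLoader

dataset = KaldiArkDataset(
    scp_file="exp/feats.scp",     # hypothetical kaldi feature scp
    label2utt="exp/label2utt",    # map file with lines "<spk_id> <utt_id> ..."
    select_by_speaker=True)
loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)

for feats, lengths, labels in loader:
    # feats: (N, feat_dim, max_chunk) float32, right-padded by batch_pad_right
    # lengths: valid fraction of the last dim per item; labels: (N, 1) ints
    break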

paddlespeech/vector/datasets/dataset.py
Deleted (file mode 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random

import numpy as np
import kaldi_python_io as k_io
from paddle.io import Dataset

from paddlespeech.vector.utils.data_utils import batch_pad_right
import paddlespeech.vector.utils as utils
from paddlespeech.vector.utils.utils import read_map_file


def ark_collate_fn(batch):
    """
    Custom collate function for kaldi feats dataset

    Args:
        min_chunk_size: min chunk size of a utterance
        max_chunk_size: max chunk size of a utterance
    Returns:
        ark_collate_fn: collate function for dataloader
    """
    data = []
    target = []
    for items in batch:
        for x, y in zip(items[0], items[1]):
            data.append(np.array(x))
            target.append(y)

    data, lengths = batch_pad_right(data)
    return np.array(data, dtype=np.float32), \
           np.array(lengths, dtype=np.float32), \
           np.array(target, dtype=np.long).reshape((len(target), 1))


class KaldiArkDataset(Dataset):
    """
    Dataset used to load kaldi ark/scp files.
    """

    def __init__(self,
                 scp_file,
                 label2utt,
                 min_item_size=1,
                 max_item_size=1,
                 repeat=50,
                 min_chunk_size=200,
                 max_chunk_size=400,
                 select_by_speaker=True):
        self.scp_file = scp_file
        self.scp_reader = None
        self.repeat = repeat
        self.min_item_size = min_item_size
        self.max_item_size = max_item_size
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self._collate_fn = ark_collate_fn
        self._is_select_by_speaker = select_by_speaker
        if utils.is_exist(self.scp_file):
            self.scp_reader = k_io.ScriptReader(self.scp_file)

        label2utts, utt2label = read_map_file(label2utt, key_func=int)
        self.utt_info = list(label2utts.items()) if self._is_select_by_speaker \
            else list(utt2label.items())

    @property
    def collate_fn(self):
        """
        Return a collate function.
        """
        return self._collate_fn

    def _random_chunk(self, length):
        chunk_size = random.randint(self.min_chunk_size, self.max_chunk_size)
        if chunk_size >= length:
            return 0, length
        start = random.randint(0, length - chunk_size)
        end = start + chunk_size
        return start, end

    def _select_by_speaker(self, index):
        if self.scp_reader is None or not self.utt_info:
            return []
        index = index % (len(self.utt_info))
        inputs = []
        labels = []
        item_size = random.randint(self.min_item_size, self.max_item_size)
        for loop_idx in range(item_size):
            try:
                utt_index = random.randint(0, len(self.utt_info[index][1])) \
                    % len(self.utt_info[index][1])
                key = self.utt_info[index][1][utt_index]
            except:
                print(index, utt_index, len(self.utt_info[index][1]))
                sys.exit(-1)
            x = self.scp_reader[key]
            x = np.transpose(x)
            bg, end = self._random_chunk(x.shape[-1])
            inputs.append(x[:, bg:end])
            labels.append(self.utt_info[index][0])
        return inputs, labels

    def _select_by_utt(self, index):
        if self.scp_reader is None or len(self.utt_info) == 0:
            return {}
        index = index % (len(self.utt_info))
        key = self.utt_info[index][0]
        x = self.scp_reader[key]
        x = np.transpose(x)
        bg, end = self._random_chunk(x.shape[-1])
        y = self.utt_info[index][1]
        return [x[:, bg:end]], [y]

    def __getitem__(self, index):
        if self._is_select_by_speaker:
            return self._select_by_speaker(index)
        else:
            return self._select_by_utt(index)

    def __len__(self):
        return len(self.utt_info) * self.repeat

    def __iter__(self):
        self._start = 0
        return self

    def __next__(self):
        if self._start < len(self):
            ret = self[self._start]
            self._start += 1
            return ret
        else:
            raise StopIteration
return KaldiArkDataset
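
For reference, the class also implements the iterator protocol (__iter__/__next__), so it could be stepped through without a DataLoader; a short sketch with hypothetical paths:

ds = KaldiArkDataset("exp/feats.scp", "exp/label2utt",
                     select_by_speaker=False)   # hypothetical paths
for inputs, labels in ds:       # drives __iter__/__next__
    chunk = inputs[0]           # one (feat_dim, chunk_len) slice per utterance
    break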

paddlespeech/vector/datasets/egs_dataset.py
Deleted (file mode 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Load nnet3 training egs generated by kaldi
"""
import random

import numpy as np
import kaldi_python_io as k_io
from paddle.io import Dataset

import paddlespeech.vector.utils.utils as utils
from paddlespeech.vector import _logger as log


class KaldiEgsDataset(Dataset):
    """
    Dataset used to load kaldi nnet3 egs files.
    """

    def __init__(self, egs_list_file, egs_idx, transforms=None):
        self.scp_reader = None
        self.subset_idx = egs_idx - 1
        self.transforms = transforms
        if not utils.is_exist(egs_list_file):
            return

        self.egs_files = []
        with open(egs_list_file, 'r') as in_fh:
            for line in in_fh:
                if line.strip():
                    self.egs_files.append(line.strip())

        self.next_subset()

    def next_subset(self, target_index=None, delta_index=None):
        """
        Use next specific subset

        Args:
            target_index: target egs index
            delta_index: incremental value of egs index
        """
        if self.egs_files:
            if target_index:
                self.subset_idx = target_index
            else:
                delta_index = delta_index if delta_index else 1
                self.subset_idx += delta_index
            log.info("egs dataset subset index: %d" % (self.subset_idx))
            egs_file = self.egs_files[self.subset_idx % len(self.egs_files)]
            if utils.is_exist(egs_file):
                self.scp_reader = k_io.Nnet3EgsScriptReader(egs_file)
            else:
                log.warning("No such file or directory: %s" % (egs_file))

    def __getitem__(self, index):
        if self.scp_reader is None:
            return {}
        index %= len(self)
        in_dict, out_dict = self.scp_reader[index]
        x = np.array(in_dict['matrix'])
        x = np.transpose(x)
        y = np.array(out_dict['matrix'][0][0][0], dtype=np.int).reshape((1, ))
        if self.transforms is not None:
            idx = random.randint(0, len(self.transforms) - 1)
            x = self.transforms[idx](x)
        return x, y

    def __len__(self):
        return len(self.scp_reader)

    def __iter__(self):
        self._start = 0
        return self

    def __next__(self):
        if self._start < len(self):
            ret = self[self._start]
            self._start += 1
            return ret
        else:
            raise StopIteration
\ No newline at end of file
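
For reference, a sketch of the epoch loop this class appears designed for: next_subset() rotates to the next kaldi egs archive between epochs (the list file and epoch count are hypothetical):

ds = KaldiEgsDataset("exp/egs.list", egs_idx=1)
for epoch in range(4):      # illustrative epoch count
    for x, y in ds:
        pass                # train on one (feats, label) example
    ds.next_subset()        # advance subset_idx and reopen the next archive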

paddlespeech/vector/utils/data_utils.py
Deleted (file mode 100755 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
data utilities
"""
import os
import sys

import numpy
import paddle


def pad_right_to(array, target_shape, mode="constant", value=0):
    """
    This function takes a numpy array of arbitrary shape and pads it to target
    shape by appending values on the right.

    Args:
        array: input numpy array whose dimensions we need to pad.
        target_shape: (list, tuple). Target shape we want for the target array;
            its length must be equal to array.ndim.
        mode: str. Pad mode, please refer to numpy.pad documentation.
        value: float. Pad value, please refer to numpy.pad documentation.
    Returns:
        array: numpy.array. Padded array.
        valid_vals: list. List containing the proportion of original,
            non-padded values for each dimension.
    """
    assert len(target_shape) == array.ndim
    pads = []  # this contains the abs length of the padding for each dimension.
    valid_vals = []  # this contains the relative lengths for each dimension.
    i = 0  # iterating over target_shape ndims
    while i < len(target_shape):
        assert (target_shape[i] >= array.shape[i]
                ), "Target shape must be >= original shape for every dim"
        pads.append([0, target_shape[i] - array.shape[i]])
        valid_vals.append(array.shape[i] / target_shape[i])
        i += 1

    array = numpy.pad(array, pads, mode=mode, constant_values=value)

    return array, valid_vals


def batch_pad_right(arrays, mode="constant", value=0):
    """Given a list of numpy arrays it batches them together by padding to the right
    on each dimension in order to get same length for all.

    Args:
        arrays: list. List of arrays we wish to pad together.
        mode: str. Padding mode, see numpy.pad documentation.
        value: float. Padding value, see numpy.pad documentation.
    Returns:
        array: numpy.array. Padded array.
        valid_vals: list. List containing the proportion of original,
            non-padded values for each dimension.
    """
    if not len(arrays):
        raise IndexError("arrays list must not be empty")

    if len(arrays) == 1:
        # if there is only one array in the batch we simply unsqueeze it.
        return numpy.expand_dims(arrays[0], axis=0), numpy.array([1.0])

    if not (any(
        [arrays[i].ndim == arrays[0].ndim for i in range(1, len(arrays))])):
        raise IndexError("All arrays must have same number of dimensions")

    # FIXME we limit the support here: we allow padding of only the last dimension
    # need to remove this when feat extraction is updated to handle multichannel.
    max_shape = []
    for dim in range(arrays[0].ndim):
        if dim != (arrays[0].ndim - 1):
            if not all(
                [x.shape[dim] == arrays[0].shape[dim] for x in arrays[1:]]):
                raise EnvironmentError(
                    "arrays should have same dimensions except for last one")
        max_shape.append(max([x.shape[dim] for x in arrays]))

    batched = []
    valid = []
    for t in arrays:
        # for each array we apply pad_right_to
        padded, valid_percent = pad_right_to(t, max_shape, mode=mode, value=value)
        batched.append(padded)
        valid.append(valid_percent[-1])

    batched = numpy.stack(batched)

    return batched, numpy.array(valid)
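
A worked example of batch_pad_right on two feature matrices of unequal length (the shapes are illustrative):

import numpy

a = numpy.ones((80, 200))    # e.g. 80 mel bins x 200 frames
b = numpy.ones((80, 150))    # shorter utterance: 150 frames
batched, valid = batch_pad_right([a, b])
print(batched.shape)         # (2, 80, 200): b is right-padded with zeros
print(valid)                 # [1.0, 0.75]: valid fraction of the last dim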
def length_to_mask(length, max_len=None, dtype=None):
    """Creates a binary mask for each sequence.
    """
    assert len(length.shape) == 1

    if max_len is None:
        max_len = paddle.cast(paddle.max(length), dtype="int64")

    # using arange to generate mask
    mask = paddle.arange(max_len, dtype=length.dtype).expand(
        [paddle.shape(length)[0], max_len]) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype

    mask = paddle.cast(mask, dtype=dtype)
    return mask
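
A worked example of length_to_mask for two sequences of length 2 and 4:

import paddle

length = paddle.to_tensor([2, 4], dtype="int64")
mask = length_to_mask(length, max_len=4, dtype="float32")
# mask == [[1., 1., 0., 0.],
#          [1., 1., 1., 1.]]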

paddlespeech/vector/utils/utils.py
Deleted (file mode 100755 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
utilities
"""
import os
import sys

import paddle
import numpy as np

from paddlespeech.vector import _logger as log


def exit_if_not_exist(in_path):
    """
    Check the existence of a file or directory; if it does not exist, exit the program.

    Args:
        in_path: input file or directory path
    """
    if not is_exist(in_path):
        sys.exit(-1)


def is_exist(in_path):
    """
    Check the existence of a file or directory.

    Args:
        in_path: input file or directory path
    Returns:
        True or False
    """
    if not os.path.exists(in_path):
        log.error("No such file or directory: %s" % (in_path))
        return False
    return True


def get_latest_file(target_dir):
    """
    Get the latest file in target directory.

    Args:
        target_dir: target directory
    Returns:
        latest_file: a string or None
    """
    items = os.listdir(target_dir)
    items.sort(key=lambda fn: os.path.getmtime(os.path.join(target_dir, fn))
               if not os.path.isdir(os.path.join(target_dir, fn)) else 0)
    latest_file = None if not items else os.path.join(target_dir, items[-1])
    return latest_file
def avg_models(models):
    """
    merge multiple models
    """
    checkpoint_dict = paddle.load(models[0])
    final_state_dict = checkpoint_dict

    if len(models) > 1:
        for model in models[1:]:
            checkpoint_dict = paddle.load(model)
            for k, v in checkpoint_dict.items():
                final_state_dict[k] += v
        for k in final_state_dict.keys():
            final_state_dict[k] /= float(len(models))
            if np.any(np.isnan(final_state_dict[k])):
                print("Nan in %s" % (k))

    return final_state_dict
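
For reference, a sketch of checkpoint averaging with avg_models; the checkpoint paths are hypothetical:

ckpts = ["exp/epoch_8.pdparams", "exp/epoch_9.pdparams", "exp/epoch_10.pdparams"]
avg_state = avg_models(ckpts)               # element-wise mean of the state dicts
paddle.save(avg_state, "exp/avg_3.pdparams")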
def Q_from_tokens(token_num):
    """
    get prior model, data from uniform, would support others (gaussian) in future
    """
    freq = [1] * token_num
    Q = paddle.to_tensor(freq, dtype='float64')
    return Q / Q.sum()
def read_map_file(map_file, key_func=None, value_func=None, values_func=None):
    """ Read map file. The first column is the key; the rest columns are values.

    Args:
        map_file: map file
        key_func: convert function for the key
        value_func: convert function for each value
        values_func: convert function for the whole value list
    Returns:
        dict: key-to-values map
        dict: value-to-key map
    """
    if not is_exist(map_file):
        sys.exit(0)

    key2val = {}
    val2key = {}
    with open(map_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            items = line.split()
            assert len(items) >= 2
            key = items[0] if not key_func else key_func(items[0])
            values = items[1:] if not value_func else [
                value_func(item) for item in items[1:]
            ]
            if values_func:
                values = values_func(values)
            key2val[key] = values
            for value in values:
                val2key[value] = key

    return key2val, val2key
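
A worked example of read_map_file on a label2utt map file with key_func=int, matching how KaldiArkDataset called it (the file contents below are illustrative):

# exp/label2utt contains:
#   0 utt_a utt_b
#   1 utt_c
label2utts, utt2label = read_map_file("exp/label2utt", key_func=int)
# label2utts == {0: ['utt_a', 'utt_b'], 1: ['utt_c']}
# utt2label  == {'utt_a': 0, 'utt_b': 0, 'utt_c': 1}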