PaddlePaddle / DeepSpeech
Commit 567286ad
Authored on Apr 10, 2022 by xiongxinlei

wrap the embedding mean and std norm, test=doc

Parent: 2b4b3e1e

Showing 5 changed files with 379 additions and 70 deletions (+379 -70)
paddlespeech/vector/exps/ecapa_tdnn/test.py   +141  -65
paddlespeech/vector/exps/ecapa_tdnn/train.py    +3   -5
paddlespeech/vector/io/dataset.py              +15   -0
paddlespeech/vector/io/embedding_norm.py      +214   -0
paddlespeech/vector/utils/time.py               +6   -0
paddlespeech/vector/exps/ecapa_tdnn/test.py
@@ -25,6 +25,7 @@ from paddleaudio.metric import compute_eer
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.batch import batch_feature_normalize
 from paddlespeech.vector.io.dataset import CSVDataset
+from paddlespeech.vector.io.embedding_norm import InputNormalization
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
 from paddlespeech.vector.training.seeding import seed_everything
@@ -32,6 +33,91 @@ from paddlespeech.vector.training.seeding import seed_everything
 logger = Log(__name__).getlog()


+def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
+                              id2embedding):
+    """Compute the embeddings for all utterances in the dataset
+
+    Args:
+        data_loader (paddle.io.DataLoader): the dataset loader
+        model (paddle.nn.Layer): the speaker verification model
+        mean_var_norm_emb (InputNormalization): the embedding mean/std normalizer
+        config (CfgNode): the experiment config
+        id2embedding (dict): utterance id to embedding map, updated in place
+    """
+    logger.info(
+        f'Computing embeddings on {data_loader.dataset.csv_path} dataset')
+    with paddle.no_grad():
+        for batch_idx, batch in enumerate(tqdm(data_loader)):
+            # stage 8-1: extract the audio embedding
+            ids, feats, lengths = batch['ids'], batch['feats'], batch[
+                'lengths']
+            embeddings = model.backbone(feats, lengths).squeeze(
+                -1)  # (N, emb_size, 1) -> (N, emb_size)
+
+            # Global embedding normalization.
+            # If we use the global embedding norm,
+            # EER can be reduced by about a relative 10%.
+            if config.global_embedding_norm and mean_var_norm_emb:
+                lengths = paddle.ones([embeddings.shape[0]])
+                embeddings = mean_var_norm_emb(embeddings, lengths)
+
+            # Update embedding dict.
+            id2embedding.update(dict(zip(ids, embeddings)))
+
+
+def compute_verification_scores(id2embedding, train_cohort, config):
+    labels = []
+    enroll_ids = []
+    test_ids = []
+    logger.info(f"read the trial from {config.verification_file}")
+    cos_sim_func = paddle.nn.CosineSimilarity(axis=-1)
+    scores = []
+    with open(config.verification_file, 'r') as f:
+        for line in f.readlines():
+            label, enroll_id, test_id = line.strip().split(' ')
+            enroll_id = enroll_id.split('.')[0].replace('/', '-')
+            test_id = test_id.split('.')[0].replace('/', '-')
+            labels.append(int(label))
+
+            enroll_emb = id2embedding[enroll_id]
+            test_emb = id2embedding[test_id]
+            score = cos_sim_func(enroll_emb, test_emb).item()
+
+            if "score_norm" in config:
+                # Getting norm stats for enroll impostors
+                enroll_rep = paddle.tile(
+                    enroll_emb, repeat_times=[train_cohort.shape[0], 1])
+                score_e_c = cos_sim_func(enroll_rep, train_cohort)
+                if "cohort_size" in config:
+                    score_e_c, _ = paddle.topk(
+                        score_e_c, k=config.cohort_size, axis=0)
+                mean_e_c = paddle.mean(score_e_c, axis=0)
+                std_e_c = paddle.std(score_e_c, axis=0)
+
+                # Getting norm stats for test impostors
+                test_rep = paddle.tile(
+                    test_emb, repeat_times=[train_cohort.shape[0], 1])
+                score_t_c = cos_sim_func(test_rep, train_cohort)
+                if "cohort_size" in config:
+                    score_t_c, _ = paddle.topk(
+                        score_t_c, k=config.cohort_size, axis=0)
+                mean_t_c = paddle.mean(score_t_c, axis=0)
+                std_t_c = paddle.std(score_t_c, axis=0)
+
+                if config.score_norm == "s-norm":
+                    score_e = (score - mean_e_c) / std_e_c
+                    score_t = (score - mean_t_c) / std_t_c
+                    score = 0.5 * (score_e + score_t)
+                elif config.score_norm == "z-norm":
+                    score = (score - mean_e_c) / std_e_c
+                elif config.score_norm == "t-norm":
+                    score = (score - mean_t_c) / std_t_c
+
+            scores.append(score)
+
+    return scores, labels
+
+
 def main(args, config):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
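For orientation, here is a small sketch (not part of the commit) of what the score-norm branch in compute_verification_scores above computes for a single trial: the raw cosine score is z-normalized against the cohort statistics of the enroll side and of the test side, and s-norm averages the two. The embeddings, cohort size, and seed below are invented stand-ins, not data from this recipe.

import paddle

paddle.seed(0)
emb_size, cohort_size = 192, 8
cos_sim_func = paddle.nn.CosineSimilarity(axis=-1)

enroll_emb = paddle.randn([emb_size])
test_emb = paddle.randn([emb_size])
train_cohort = paddle.randn([cohort_size, emb_size])  # stacked train embeddings

raw_score = cos_sim_func(enroll_emb, test_emb).item()

# Score each trial side against every cohort (impostor) embedding.
score_e_c = cos_sim_func(
    paddle.tile(enroll_emb, repeat_times=[cohort_size, 1]), train_cohort)
score_t_c = cos_sim_func(
    paddle.tile(test_emb, repeat_times=[cohort_size, 1]), train_cohort)

# s-norm: average the score normalized by each side's cohort mean/std.
score_e = (raw_score - paddle.mean(score_e_c)) / paddle.std(score_e_c)
score_t = (raw_score - paddle.mean(score_t_c)) / paddle.std(score_t_c)
print((0.5 * (score_e + score_t)).item())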
@@ -67,7 +153,7 @@ def main(args, config):
         hop_length=config.hop_size)
     enroll_sampler = BatchSampler(
         enroll_dataset,
         batch_size=config.batch_size,
-        shuffle=True)  # Shuffle to make embedding normalization more robust.
+        shuffle=False)  # Shuffle to make embedding normalization more robust.
     enroll_loader = DataLoader(enroll_dataset,
                                batch_sampler=enroll_sampler,
                                collate_fn=lambda x: batch_feature_normalize(
@@ -83,7 +169,7 @@ def main(args, config):
         hop_length=config.hop_size)
     test_sampler = BatchSampler(
-        test_dataset, batch_size=config.batch_size, shuffle=True)
+        test_dataset, batch_size=config.batch_size, shuffle=False)
     test_loader = DataLoader(test_dataset,
                              batch_sampler=test_sampler,
                              collate_fn=lambda x: batch_feature_normalize(
@@ -95,75 +181,65 @@ def main(args, config):
     # stage6: global embedding norm to improve the performance
     logger.info(f"global embedding norm: {config.global_embedding_norm}")
-    if config.global_embedding_norm:
-        global_embedding_mean = None
-        global_embedding_std = None
-        mean_norm_flag = config.embedding_mean_norm
-        std_norm_flag = config.embedding_std_norm
-        batch_count = 0

     # stage7: Compute embeddings of audios in enroll and test dataset from model.
+    if config.global_embedding_norm:
+        mean_var_norm_emb = InputNormalization(
+            norm_type="global",
+            mean_norm=config.embedding_mean_norm,
+            std_norm=config.embedding_std_norm)
+
+    if "score_norm" in config:
+        logger.info(f"we will do score norm: {config.score_norm}")
+        train_dataset = CSVDataset(
+            os.path.join(args.data_dir, "vox/csv/train.csv"),
+            feat_type='melspectrogram',
+            n_train_snts=config.n_train_snts,
+            random_chunk=False,
+            n_mels=config.n_mels,
+            window_size=config.window_size,
+            hop_length=config.hop_size)
+        train_sampler = BatchSampler(
+            train_dataset, batch_size=config.batch_size, shuffle=False)
+        train_loader = DataLoader(
+            train_dataset,
+            batch_sampler=train_sampler,
+            collate_fn=lambda x: batch_feature_normalize(
+                x, mean_norm=True, std_norm=False),
+            num_workers=config.num_workers,
+            return_list=True, )
+
     id2embedding = {}
     # Run multiple times to make embedding normalization more stable.
-    for i in range(2):
-        for dl in [enroll_loader, test_loader]:
-            logger.info(
-                f'Loop {[i + 1]}: Computing embeddings on {dl.dataset.csv_path} dataset'
-            )
-            with paddle.no_grad():
-                for batch_idx, batch in enumerate(tqdm(dl)):
-                    # stage 8-1: extract the audio embedding
-                    ids, feats, lengths = batch['ids'], batch['feats'], batch[
-                        'lengths']
-                    embeddings = model.backbone(feats, lengths).squeeze(
-                        -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
-
-                    # Global embedding normalization.
-                    # If we use the global embedding norm,
-                    # EER can be reduced by about a relative 10%.
-                    if config.global_embedding_norm:
-                        batch_count += 1
-                        current_mean = embeddings.mean(
-                            axis=0) if mean_norm_flag else 0
-                        current_std = embeddings.std(
-                            axis=0) if std_norm_flag else 1
-                        # Update global mean and std.
-                        if global_embedding_mean is None and global_embedding_std is None:
-                            global_embedding_mean, global_embedding_std = current_mean, current_std
-                        else:
-                            weight = 1 / batch_count  # Weight decay by batches.
-                            global_embedding_mean = (
-                                1 - weight
-                            ) * global_embedding_mean + weight * current_mean
-                            global_embedding_std = (
-                                1 - weight
-                            ) * global_embedding_std + weight * current_std
-                        # Apply global embedding normalization.
-                        embeddings = (embeddings - global_embedding_mean
-                                      ) / global_embedding_std
-
-                    # Update embedding dict.
-                    id2embedding.update(dict(zip(ids, embeddings)))
+    logger.info("First loop for enroll and test dataset")
+    compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
+                              id2embedding)
+    compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
+                              id2embedding)
+
+    logger.info("Second loop for enroll and test dataset")
+    compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
+                              id2embedding)
+    compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
+                              id2embedding)
+    mean_var_norm_emb.save(
+        os.path.join(args.load_checkpoint, "mean_var_norm_emb"))

     # stage 8: Compute cosine scores.
-    labels = []
-    enroll_ids = []
-    test_ids = []
-    logger.info(f"read the trial from {config.verification_file}")
-    with open(config.verification_file, 'r') as f:
-        for line in f.readlines():
-            label, enroll_id, test_id = line.strip().split(' ')
-            labels.append(int(label))
-            enroll_ids.append(enroll_id.split('.')[0].replace('/', '-'))
-            test_ids.append(test_id.split('.')[0].replace('/', '-'))
-
-    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
-    enrol_embeddings, test_embeddings = map(
-        lambda ids: paddle.to_tensor(
-            np.asarray([id2embedding[uttid] for uttid in ids], dtype='float32')),
-        [enroll_ids, test_ids])  # (N, emb_size)
-    scores = cos_sim_func(enrol_embeddings, test_embeddings)
+    train_cohort = None
+    if "score_norm" in config:
+        train_embeddings = {}
+        # cohort embeddings do not get mean and std norm
+        compute_dataset_embedding(train_loader, model, None, config,
+                                  train_embeddings)
+        train_cohort = paddle.stack(list(train_embeddings.values()))
+
+    # compute the scores
+    scores, labels = compute_verification_scores(id2embedding, train_cohort,
+                                                 config)

     # compute the EER and threshold
+    scores = paddle.to_tensor(scores)
     EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
     logger.info(
         f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}')
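The enroll and test loaders are pushed through compute_dataset_embedding twice ("First loop" / "Second loop") because InputNormalization refines its global statistics incrementally with weight 1/(count+1): during the first pass the early batches are normalized with immature statistics, and the second pass re-normalizes every utterance with the settled mean and std. A quick plain-Python check, with invented per-batch means, that this update rule is an exact running average:

count, glob_mean = 0, 0.0
for current_mean in [2.0, 4.0, 6.0]:  # stand-ins for per-batch embedding means
    weight = 1 / (count + 1)
    glob_mean = (1 - weight) * glob_mean + weight * current_mean
    count += 1
print(glob_mean)  # 4.0, i.e. mean([2, 4, 6])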
paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -197,17 +197,15 @@ def main(args, config):
                               paddle.optimizer.lr.LRScheduler):
                 optimizer._learning_rate.step()
             optimizer.clear_grad()
-            train_run_cost += time.time() - train_start

             # stage 9-8: Calculate average loss per batch
-            train_misce_start = time.time()
             avg_loss = loss.item()

             # stage 9-9: Calculate metrics, which is one-best accuracy
             preds = paddle.argmax(logits, axis=1)
             num_corrects += (preds == labels).numpy().sum()
             num_samples += feats.shape[0]
+            train_run_cost += time.time() - train_start
             timer.count()  # step plus one in timer

             # stage 9-10: print the log information only on 0-rank per log-freq batches
@@ -227,8 +225,8 @@ def main(args, config):
                     print_msg += ' avg_train_cost: {:.5f} sec,'.format(
                         train_run_cost / config.log_interval)
-                    print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
-                        lr, timer.timing, timer.eta)
+                    print_msg += ' lr={:.4E} step/sec={:.2f} ips={:.2f} | ETA {}'.format(
+                        lr, timer.timing, timer.ips, timer.eta)
                     logger.info(print_msg)

                     avg_loss = 0
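The widened log line can be previewed in isolation; every value below is invented, only the format string comes from the diff above:

lr, timing, ips, eta = 1.0e-4, 0.40, 160.00, '0:12:34'  # hypothetical values
print_msg = ' lr={:.4E} step/sec={:.2f} ips={:.2f} | ETA {}'.format(
    lr, timing, ips, eta)
print(print_msg)  # " lr=1.0000E-04 step/sec=0.40 ips=160.00 | ETA 0:12:34"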
paddlespeech/vector/io/dataset.py
@@ -65,6 +65,7 @@ class CSVDataset(Dataset):
                  config=None,
                  random_chunk=True,
                  feat_type: str="raw",
+                 n_train_snts: int=-1,
                  **kwargs):
         """Implement the CSV Dataset
@@ -73,6 +74,9 @@ class CSVDataset(Dataset):
             label2id_path (str): the utterance label to integer id map file path
             config (CfgNode): yaml config
             feat_type (str): dataset feature type. if it is raw, it returns pcm data.
+            n_train_snts (int): select the first n_train_snts samples from the dataset.
+                                if n_train_snts = -1, the dataset will load all the samples.
+                                Default value is -1.
             kwargs : feature type args
         """
         super().__init__()
@@ -81,6 +85,7 @@ class CSVDataset(Dataset):
         self.config = config
         self.random_chunk = random_chunk
         self.feat_type = feat_type
+        self.n_train_snts = n_train_snts
         self.feat_config = kwargs
         self.id2label = {}
         self.label2id = {}
@@ -93,6 +98,9 @@ class CSVDataset(Dataset):
            that is audio_id or utt_id, audio duration, segment start point, segment stop point
            and utterance label.
            Note in training period, the utterance label must have a map to integer id in label2id_path
+
+        Returns:
+            list: the csv data with meta_info type
         """
         data = []
@@ -104,6 +112,10 @@ class CSVDataset(Dataset):
                     meta_info(audio_id,
                               float(duration), wav,
                               int(start), int(stop), spk_id))
+        if self.n_train_snts > 0:
+            sample_num = min(self.n_train_snts, len(data))
+            data = data[0:sample_num]
+
         return data

     def load_speaker_to_label(self):
@@ -173,5 +185,8 @@ class CSVDataset(Dataset):
     def __len__(self):
         """Return the dataset length
+
+        Returns:
+            int: the number of samples in the dataset
         """
         return len(self.data)
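A plain-Python sketch of what the new n_train_snts guard does to the parsed rows; the list entries here are hypothetical stand-ins for the real meta_info records:

data = ["utt-%d" % i for i in range(10)]  # stand-ins for parsed csv rows

n_train_snts = 4  # e.g. config.n_train_snts as passed from test.py
if n_train_snts > 0:
    sample_num = min(n_train_snts, len(data))
    data = data[0:sample_num]
print(len(data))  # 4; with the default n_train_snts = -1 all rows are kept

This is what lets the score-norm cohort in test.py be built from only a subset of train.csv.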
paddlespeech/vector/io/embedding_norm.py
New file (0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import paddle


class InputNormalization:
    spk_dict_mean: Dict[int, paddle.Tensor]
    spk_dict_std: Dict[int, paddle.Tensor]
    spk_dict_count: Dict[int, int]

    def __init__(
            self,
            mean_norm=True,
            std_norm=True,
            norm_type="global", ):
        """Do feature or embedding mean and std norm

        Args:
            mean_norm (bool, optional): mean norm flag. Defaults to True.
            std_norm (bool, optional): std norm flag. Defaults to True.
            norm_type (str, optional): norm type. Defaults to "global".
        """
        super().__init__()
        self.training = True
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.glob_mean = paddle.to_tensor([0], dtype="float32")
        self.glob_std = paddle.to_tensor([0], dtype="float32")
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10

    def __call__(self,
                 x,
                 lengths,
                 spk_ids=paddle.to_tensor([], dtype="float32")):
        """Normalize the input tensor with the accumulated statistics.

        Args:
            x (paddle.Tensor): A batch of tensors.
            lengths (paddle.Tensor): A batch of tensors containing the relative length of each
                sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
                computing stats on zero-padded steps.
            spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
                It is used to perform per-speaker normalization when
                norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32").

        Returns:
            paddle.Tensor: The normalized feature or embedding
        """
        N_batches = x.shape[0]
        # print(f"x shape: {x.shape[1]}")

        current_means = []
        current_stds = []
        for snt_id in range(N_batches):
            # Avoiding padded time steps
            # actual size is the actual time data length
            actual_size = paddle.round(lengths[snt_id] *
                                       x.shape[1]).astype("int32")

            # computing actual time data statistics
            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...].unsqueeze(0))
            current_means.append(current_mean)
            current_stds.append(current_std)

        if self.norm_type == "global":
            current_mean = paddle.mean(paddle.stack(current_means), axis=0)
            current_std = paddle.mean(paddle.stack(current_stds), axis=0)

        if self.norm_type == "global":
            if self.training:
                if self.count == 0:
                    self.glob_mean = current_mean
                    self.glob_std = current_std
                else:
                    self.weight = 1 / (self.count + 1)
                    self.glob_mean = (1 - self.weight) * self.glob_mean \
                        + self.weight * current_mean
                    self.glob_std = (1 - self.weight) * self.glob_std \
                        + self.weight * current_std

                self.glob_mean.detach()
                self.glob_std.detach()

                self.count = self.count + 1
            x = (x - self.glob_mean) / (self.glob_std)
        return x

    def _compute_current_stats(self, x):
        """Compute the mean and std statistics of the input tensor.

        Args:
            x (paddle.Tensor): A batch of tensors.

        Returns:
            the statistics of the data
        """
        # Compute current mean
        if self.mean_norm:
            current_mean = paddle.mean(x, axis=0).detach()
        else:
            current_mean = paddle.to_tensor([0.0], dtype="float32")

        # Compute current std
        if self.std_norm:
            current_std = paddle.std(x, axis=0).detach()
        else:
            current_std = paddle.to_tensor([1.0], dtype="float32")

        # Improving numerical stability of std
        current_std = paddle.maximum(current_std,
                                     self.eps * paddle.ones_like(current_std))

        return current_mean, current_std

    def _statistics_dict(self):
        """Fills the dictionary containing the normalization statistics.
        """
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count

        return state

    def _load_statistics_dict(self, state):
        """Loads the dictionary containing the statistics.

        Arguments
        ---------
        state : dict
            A dictionary containing the normalization statistics.
        """
        self.count = state["count"]
        if isinstance(state["glob_mean"], int):
            self.glob_mean = state["glob_mean"]
            self.glob_std = state["glob_std"]
        else:
            self.glob_mean = state["glob_mean"]  # .to(self.device_inp)
            self.glob_std = state["glob_std"]  # .to(self.device_inp)

        # Loading the spk_dict_mean in the right device
        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]

        # Loading the spk_dict_std in the right device
        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk]

        self.spk_dict_count = state["spk_dict_count"]

        return state

    def to(self, device):
        """Puts the needed tensors in the right device.
        """
        self = super(InputNormalization, self).to(device)
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
        return self

    def save(self, path):
        """Save statistic dictionary.

        Args:
            path (str): A path where to save the dictionary.
        """
        stats = self._statistics_dict()
        paddle.save(stats, path)

    def _load(self, path, end_of_epoch=False, device=None):
        """Load statistic dictionary.

        Arguments
        ---------
        path : str
            The path of the statistic dictionary
        device : str, None
            Passed to paddle.load(..., map_location=device)
        """
        del end_of_epoch  # Unused here.
        stats = paddle.load(path, map_location=device)
        self._load_statistics_dict(stats)
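A hedged usage sketch of the new class, mirroring how test.py drives it for embeddings; the batch shapes and values below are invented. std_norm is disabled in this sketch because each per-utterance slice holds a single vector, for which an unbiased std is degenerate; whether the recipe enables embedding_std_norm is decided by its config, which is not shown here.

import paddle

from paddlespeech.vector.io.embedding_norm import InputNormalization

norm = InputNormalization(norm_type="global", mean_norm=True, std_norm=False)

for _ in range(3):  # each call folds this batch into the running global stats
    embeddings = paddle.randn([8, 192])           # (N, emb_size), invented
    lengths = paddle.ones([embeddings.shape[0]])  # relative length 1.0: no padding
    embeddings = norm(embeddings, lengths)

# Persist count/glob_mean/glob_std, as test.py does after its second loop.
norm.save("mean_var_norm_emb")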
paddlespeech/vector/utils/time.py
@@ -23,6 +23,7 @@ class Timer(object):
         self.last_start_step = 0
         self.current_step = 0
         self._is_running = True
+        self._ips = 0

     def start(self):
         self.last_time = time.time()
@@ -43,12 +44,17 @@ class Timer(object):
         self.last_start_step = self.current_step
         time_used = time.time() - self.last_time
         self.last_time = time.time()
+        self._ips = run_steps / time_used
         return time_used / run_steps

     @property
     def is_running(self) -> bool:
         return self._is_running

+    @property
+    def ips(self) -> float:
+        return self._ips
+
     @property
     def eta(self) -> str:
         if not self.is_running:
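In timing(), the stored throughput and the returned value are reciprocals over the same measurement window; a toy calculation with invented numbers:

run_steps, time_used = 50, 2.0        # hypothetical steps and seconds in a window
ips = run_steps / time_used           # 25.0 steps per second, exposed as timer.ips
sec_per_step = time_used / run_steps  # 0.04, what timing() returns
print(ips, sec_per_step)              # 25.0 0.04

Note the backing field is kept as _ips so that the ips property does not shadow its own storage.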