Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
378fe590
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
378fe590
编写于
4月 05, 2022
作者:
C
ccrrong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ami diarization pipeline, test=doc
上级
a2c0fbf2
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
1182 addition
and
5 deletion
+1182
-5
examples/ami/sd0/conf/ecapa_tdnn.yaml
examples/ami/sd0/conf/ecapa_tdnn.yaml
+71
-0
examples/ami/sd0/local/ami_dataset.py
examples/ami/sd0/local/ami_dataset.py
+90
-0
examples/ami/sd0/local/compute_embdding.py
examples/ami/sd0/local/compute_embdding.py
+233
-0
examples/ami/sd0/local/experiment.py
examples/ami/sd0/local/experiment.py
+439
-0
examples/ami/sd0/local/process.sh
examples/ami/sd0/local/process.sh
+49
-0
examples/ami/sd0/run.sh
examples/ami/sd0/run.sh
+31
-5
paddlespeech/vector/cluster/diarization.py
paddlespeech/vector/cluster/diarization.py
+94
-0
utils/compute_der.py
utils/compute_der.py
+175
-0
未找到文件。
examples/ami/sd0/conf/ecapa_tdnn.yaml
0 → 100755
浏览文件 @
378fe590
# ##################################################
# Model: Speaker Diarization Baseline
# Embeddings: Deep embedding
# Clustering Technique: Spectral clustering
# Authors: Nauman Dawalatabad 2020
# #################################################
seed
:
1234
num_speakers
:
7205
###########################################################
# AMI DATA PREPARE SETTING #
###########################################################
split_type
:
'
full_corpus_asr'
skip_TNO
:
True
# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
mic_type
:
'
Mix-Headset'
vad_type
:
'
oracle'
max_subseg_dur
:
3.0
overlap
:
1.5
# Some more exp folders (for cleaner structure).
embedding_dir
:
emb
#!ref <save_folder>/emb
meta_data_dir
:
metadata
#!ref <save_folder>/metadata
ref_rttm_dir
:
ref_rttms
#!ref <save_folder>/ref_rttms
sys_rttm_dir
:
sys_rttms
#!ref <save_folder>/sys_rttms
der_dir
:
DER
#!ref <save_folder>/DER
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
sr
:
16000
# sample rate
n_mels
:
80
window_size
:
400
#25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_size
:
160
#10ms, sample rate 16000, 10 * 16000 / 1000 = 160
#left_frames: 0
#right_frames: 0
#deltas: False
###########################################################
# MODEL SETTING #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if we want use another model, please choose another configuration yaml file
emb_dim
:
192
batch_size
:
16
model
:
input_size
:
80
channels
:
[
1024
,
1024
,
1024
,
1024
,
3072
]
kernel_sizes
:
[
5
,
3
,
3
,
3
,
1
]
dilations
:
[
1
,
2
,
3
,
4
,
1
]
attention_channels
:
128
lin_neurons
:
192
# Will automatically download ECAPA-TDNN model (best).
###########################################################
# SPECTRAL CLUSTERING SETTING #
###########################################################
backend
:
'
SC'
# options: 'kmeans' # Note: kmeans goes only with cos affinity
affinity
:
'
cos'
# options: cos, nn
max_num_spkrs
:
10
oracle_n_spkrs
:
True
###########################################################
# DER EVALUATION SETTING #
###########################################################
ignore_overlap
:
True
forgiveness_collar
:
0.25
examples/ami/sd0/local/ami_dataset.py
0 → 100644
浏览文件 @
378fe590
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
collections
import
json
from
paddle.io
import
Dataset
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.datasets.dataset
import
feat_funcs
class
AMIDataset
(
Dataset
):
"""
AMI dataset.
"""
meta_info
=
collections
.
namedtuple
(
'META_INFO'
,
(
'id'
,
'duration'
,
'wav'
,
'start'
,
'stop'
,
'record_id'
))
def
__init__
(
self
,
json_file
:
str
,
feat_type
:
str
=
'raw'
,
**
kwargs
):
"""
Ags:
json_file (:obj:`str`): Data prep JSON file.
labels (:obj:`List[int]`): Labels of audio files.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
if
feat_type
not
in
feat_funcs
.
keys
():
raise
RuntimeError
(
f
"Unknown feat_type:
{
feat_type
}
, it must be one in
{
list
(
feat_funcs
.
keys
())
}
"
)
self
.
json_file
=
json_file
self
.
feat_type
=
feat_type
self
.
feat_config
=
kwargs
self
.
_data
=
self
.
_get_data
()
super
(
AMIDataset
,
self
).
__init__
()
def
_get_data
(
self
):
with
open
(
self
.
json_file
,
"r"
)
as
f
:
meta_data
=
json
.
load
(
f
)
data
=
[]
for
key
in
meta_data
:
sub_seg
=
meta_data
[
key
][
"wav"
]
wav
=
sub_seg
[
"file"
]
duration
=
sub_seg
[
"duration"
]
start
=
sub_seg
[
"start"
]
stop
=
sub_seg
[
"stop"
]
rec_id
=
str
(
key
).
rsplit
(
"_"
,
2
)[
0
]
data
.
append
(
self
.
meta_info
(
str
(
key
),
float
(
duration
),
wav
,
int
(
start
),
int
(
stop
),
str
(
rec_id
)))
return
data
def
_convert_to_record
(
self
,
idx
:
int
):
sample
=
self
.
_data
[
idx
]
record
=
{}
# To show all fields in a namedtuple: `type(sample)._fields`
for
field
in
type
(
sample
).
_fields
:
record
[
field
]
=
getattr
(
sample
,
field
)
waveform
,
sr
=
load_audio
(
record
[
'wav'
])
waveform
=
waveform
[
record
[
'start'
]:
record
[
'stop'
]]
feat_func
=
feat_funcs
[
self
.
feat_type
]
feat
=
feat_func
(
waveform
,
sr
=
sr
,
**
self
.
feat_config
)
if
feat_func
else
waveform
record
.
update
({
'feat'
:
feat
})
return
record
def
__getitem__
(
self
,
idx
):
return
self
.
_convert_to_record
(
idx
)
def
__len__
(
self
):
return
len
(
self
.
_data
)
examples/ami/sd0/local/compute_embdding.py
0 → 100644
浏览文件 @
378fe590
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
json
import
os
import
pickle
import
sys
import
numpy
as
np
import
paddle
from
ami_dataset
import
AMIDataset
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
tqdm.contrib
import
tqdm
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.cluster.diarization
import
EmbeddingMeta
from
paddlespeech.vector.io.batch
import
batch_feature_normalize
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
from
paddlespeech.vector.training.seeding
import
seed_everything
# Logger setup
logger
=
Log
(
__name__
).
getlog
()
def
prepare_subset_json
(
full_meta_data
,
rec_id
,
out_meta_file
):
"""Prepares metadata for a given recording ID.
Arguments
---------
full_meta_data : json
Full meta (json) containing all the recordings
rec_id : str
The recording ID for which meta (json) has to be prepared
out_meta_file : str
Path of the output meta (json) file.
"""
subset
=
{}
for
key
in
full_meta_data
:
k
=
str
(
key
)
if
k
.
startswith
(
rec_id
):
subset
[
key
]
=
full_meta_data
[
key
]
with
open
(
out_meta_file
,
mode
=
"w"
)
as
json_f
:
json
.
dump
(
subset
,
json_f
,
indent
=
2
)
def
create_dataloader
(
json_file
,
batch_size
):
"""Creates the datasets and their data processing pipelines.
This is used for multi-mic processing.
"""
# create datasets
dataset
=
AMIDataset
(
json_file
=
json_file
,
feat_type
=
'melspectrogram'
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_size
)
# create dataloader
batch_sampler
=
BatchSampler
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
True
)
dataloader
=
DataLoader
(
dataset
,
batch_sampler
=
batch_sampler
,
collate_fn
=
lambda
x
:
batch_feature_normalize
(
x
,
mean_norm
=
True
,
std_norm
=
False
),
return_list
=
True
)
return
dataloader
def
main
(
args
,
config
):
# set the training device, cpu or gpu
paddle
.
set_device
(
args
.
device
)
# set the random seed
seed_everything
(
config
.
seed
)
# stage1: build the dnn backbone model network
ecapa_tdnn
=
EcapaTdnn
(
**
config
.
model
)
# stage2: build the speaker verification eval instance with backbone model
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
config
.
num_speakers
)
# stage3: load the pre-trained model
# we get the last model from the epoch and save_interval
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
# load model checkpoint to sid model
state_dict
=
paddle
.
load
(
os
.
path
.
join
(
args
.
load_checkpoint
,
'model.pdparams'
))
model
.
set_state_dict
(
state_dict
)
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
# set the model to eval mode
model
.
eval
()
# load meta data
meta_file
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
meta_data_dir
,
"ami_"
+
args
.
dataset
+
"."
+
config
.
mic_type
+
".subsegs.json"
,
)
with
open
(
meta_file
,
"r"
)
as
f
:
full_meta
=
json
.
load
(
f
)
# get all the recording IDs in this dataset.
all_keys
=
full_meta
.
keys
()
A
=
[
word
.
rstrip
().
split
(
"_"
)[
0
]
for
word
in
all_keys
]
all_rec_ids
=
list
(
set
(
A
[
1
:]))
all_rec_ids
.
sort
()
split
=
"AMI_"
+
args
.
dataset
i
=
1
msg
=
"Extra embdding for "
+
args
.
dataset
+
" set"
logger
.
info
(
msg
)
if
len
(
all_rec_ids
)
<=
0
:
msg
=
"No recording IDs found! Please check if meta_data json file is properly generated."
logger
.
error
(
msg
)
sys
.
exit
()
# extra different recordings embdding in a dataset.
for
rec_id
in
tqdm
(
all_rec_ids
):
# This tag will be displayed in the log.
tag
=
(
"["
+
str
(
args
.
dataset
)
+
": "
+
str
(
i
)
+
"/"
+
str
(
len
(
all_rec_ids
))
+
"]"
)
i
=
i
+
1
# log message.
msg
=
"Embdding %s : %s "
%
(
tag
,
rec_id
)
logger
.
debug
(
msg
)
# embedding directory.
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
data_dir
,
config
.
embedding_dir
,
split
)):
os
.
makedirs
(
os
.
path
.
join
(
args
.
data_dir
,
config
.
embedding_dir
,
split
))
# file to store embeddings.
emb_file_name
=
rec_id
+
"."
+
config
.
mic_type
+
".emb_stat.pkl"
diary_stat_emb_file
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
embedding_dir
,
split
,
emb_file_name
)
# prepare a metadata (json) for one recording. This is basically a subset of full_meta.
# lets keep this meta-info in embedding directory itself.
json_file_name
=
rec_id
+
"."
+
config
.
mic_type
+
".json"
meta_per_rec_file
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
embedding_dir
,
split
,
json_file_name
)
# write subset (meta for one recording) json metadata.
prepare_subset_json
(
full_meta
,
rec_id
,
meta_per_rec_file
)
# prepare data loader.
diary_set_loader
=
create_dataloader
(
meta_per_rec_file
,
config
.
batch_size
)
# extract embeddings (skip if already done).
if
not
os
.
path
.
isfile
(
diary_stat_emb_file
):
logger
.
debug
(
"Extracting deep embeddings"
)
embeddings
=
np
.
empty
(
shape
=
[
0
,
config
.
emb_dim
],
dtype
=
np
.
float64
)
segset
=
[]
for
batch_idx
,
batch
in
enumerate
(
tqdm
(
diary_set_loader
)):
# extrac the audio embedding
ids
,
feats
,
lengths
=
batch
[
'ids'
],
batch
[
'feats'
],
batch
[
'lengths'
]
seg
=
[
x
for
x
in
ids
]
segset
=
segset
+
seg
emb
=
model
.
backbone
(
feats
,
lengths
).
squeeze
(
-
1
).
numpy
()
# (N, emb_size, 1) -> (N, emb_size)
embeddings
=
np
.
concatenate
((
embeddings
,
emb
),
axis
=
0
)
segset
=
np
.
array
(
segset
,
dtype
=
"|O"
)
stat_obj
=
EmbeddingMeta
(
segset
=
segset
,
stats
=
embeddings
,
)
logger
.
debug
(
"Saving Embeddings..."
)
with
open
(
diary_stat_emb_file
,
"wb"
)
as
output
:
pickle
.
dump
(
stat_obj
,
output
)
else
:
logger
.
debug
(
"Skipping embedding extraction (as already present)."
)
# Begin experiment!
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
'--device'
,
default
=
"gpu"
,
help
=
"Select which device to perform diarization, defaults to gpu."
)
parser
.
add_argument
(
"--config"
,
default
=
None
,
type
=
str
,
help
=
"configuration file"
)
parser
.
add_argument
(
"--data-dir"
,
default
=
"../save/"
,
type
=
str
,
help
=
"processsed data directory"
)
parser
.
add_argument
(
"--dataset"
,
choices
=
[
'dev'
,
'eval'
],
default
=
"dev"
,
type
=
str
,
help
=
"Select which dataset to extra embdding, defaults to dev"
)
parser
.
add_argument
(
"--load-checkpoint"
,
type
=
str
,
default
=
''
,
help
=
"Directory to load model checkpoint to compute embeddings."
)
args
=
parser
.
parse_args
()
config
=
CfgNode
(
new_allowed
=
True
)
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
config
.
freeze
()
print
(
config
)
main
(
args
,
config
)
examples/ami/sd0/local/experiment.py
0 → 100755
浏览文件 @
378fe590
#!/usr/bin/python3
"""This recipe implements diarization system using deep embedding extraction followed by spectral clustering.
To run this recipe:
> python experiment.py hparams/<your_hyperparams_file.yaml>
e.g., python experiment.py hparams/ecapa_tdnn.yaml
Condition: Oracle VAD (speech regions taken from the groundtruth).
Note: There are multiple ways to write this recipe. We iterate over individual recordings.
This approach is less GPU memory demanding and also makes code easy to understand.
Citation: This recipe is based on the following paper,
N. Dawalatabad, M. Ravanelli, F. Grondin, J. Thienpondt, B. Desplanques, H. Na,
"ECAPA-TDNN Embeddings for Speaker Diarization," arXiv:2104.01466, 2021.
Authors
* Nauman Dawalatabad 2020
"""
import
argparse
import
glob
import
json
import
os
import
pickle
import
shutil
import
sys
import
numpy
as
np
from
tqdm.contrib
import
tqdm
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.cluster
import
diarization
as
diar
from
utils.compute_der
import
DER
# Logger setup
logger
=
Log
(
__name__
).
getlog
()
def
diarize_dataset
(
full_meta
,
split_type
,
n_lambdas
,
pval
,
save_dir
,
config
,
n_neighbors
=
10
,
):
"""This function diarizes all the recordings in a given dataset. It performs
computation of embedding and clusters them using spectral clustering (or other backends).
The output speaker boundary file is stored in the RTTM format.
"""
# prepare `spkr_info` only once when Oracle num of speakers is selected.
# spkr_info is essential to obtain number of speakers from groundtruth.
if
config
.
oracle_n_spkrs
is
True
:
full_ref_rttm_file
=
os
.
path
.
join
(
save_dir
,
config
.
ref_rttm_dir
,
"fullref_ami_"
+
split_type
+
".rttm"
)
rttm
=
diar
.
read_rttm
(
full_ref_rttm_file
)
spkr_info
=
list
(
# noqa F841
filter
(
lambda
x
:
x
.
startswith
(
"SPKR-INFO"
),
rttm
))
# get all the recording IDs in this dataset.
all_keys
=
full_meta
.
keys
()
A
=
[
word
.
rstrip
().
split
(
"_"
)[
0
]
for
word
in
all_keys
]
all_rec_ids
=
list
(
set
(
A
[
1
:]))
all_rec_ids
.
sort
()
split
=
"AMI_"
+
split_type
i
=
1
# adding tag for directory path.
type_of_num_spkr
=
"oracle"
if
config
.
oracle_n_spkrs
else
"est"
tag
=
(
type_of_num_spkr
+
"_"
+
str
(
config
.
affinity
)
+
"_"
+
config
.
backend
)
# make out rttm dir
out_rttm_dir
=
os
.
path
.
join
(
save_dir
,
config
.
sys_rttm_dir
,
config
.
mic_type
,
split
,
tag
)
if
not
os
.
path
.
exists
(
out_rttm_dir
):
os
.
makedirs
(
out_rttm_dir
)
# diarizing different recordings in a dataset.
for
rec_id
in
tqdm
(
all_rec_ids
):
# this tag will be displayed in the log.
if
rec_id
==
"IS1008a"
:
continue
if
rec_id
==
"ES2011a"
:
continue
tag
=
(
"["
+
str
(
split_type
)
+
": "
+
str
(
i
)
+
"/"
+
str
(
len
(
all_rec_ids
))
+
"]"
)
i
=
i
+
1
# log message.
msg
=
"Diarizing %s : %s "
%
(
tag
,
rec_id
)
logger
.
debug
(
msg
)
# load embeddings.
emb_file_name
=
rec_id
+
"."
+
config
.
mic_type
+
".emb_stat.pkl"
diary_stat_emb_file
=
os
.
path
.
join
(
save_dir
,
config
.
embedding_dir
,
split
,
emb_file_name
)
if
not
os
.
path
.
isfile
(
diary_stat_emb_file
):
msg
=
"Embdding file %s not found! Please check if embdding file is properly generated."
%
(
diary_stat_emb_file
)
logger
.
error
(
msg
)
sys
.
exit
()
with
open
(
diary_stat_emb_file
,
"rb"
)
as
in_file
:
diary_obj
=
pickle
.
load
(
in_file
)
out_rttm_file
=
out_rttm_dir
+
"/"
+
rec_id
+
".rttm"
# processing starts from here.
if
config
.
oracle_n_spkrs
is
True
:
# oracle num of speakers.
num_spkrs
=
diar
.
get_oracle_num_spkrs
(
rec_id
,
spkr_info
)
else
:
if
config
.
affinity
==
"nn"
:
# num of speakers tunned on dev set (only for nn affinity).
num_spkrs
=
n_lambdas
else
:
# num of speakers will be estimated using max eigen gap for cos based affinity.
# so adding None here. Will use this None later-on.
num_spkrs
=
None
if
config
.
backend
==
"kmeans"
:
diar
.
do_kmeans_clustering
(
diary_obj
,
out_rttm_file
,
rec_id
,
num_spkrs
,
pval
,
)
if
config
.
backend
==
"SC"
:
# go for Spectral Clustering (SC).
diar
.
do_spec_clustering
(
diary_obj
,
out_rttm_file
,
rec_id
,
num_spkrs
,
pval
,
config
.
affinity
,
n_neighbors
,
)
# can used for AHC later. Likewise one can add different backends here.
if
config
.
backend
==
"AHC"
:
# call AHC
threshold
=
pval
# pval for AHC is nothing but threshold.
diar
.
do_AHC
(
diary_obj
,
out_rttm_file
,
rec_id
,
num_spkrs
,
threshold
)
# once all RTTM outputs are generated, concatenate individual RTTM files to obtain single RTTM file.
# this is not needed but just staying with the standards.
concate_rttm_file
=
out_rttm_dir
+
"/sys_output.rttm"
logger
.
debug
(
"Concatenating individual RTTM files..."
)
with
open
(
concate_rttm_file
,
"w"
)
as
cat_file
:
for
f
in
glob
.
glob
(
out_rttm_dir
+
"/*.rttm"
):
if
f
==
concate_rttm_file
:
continue
with
open
(
f
,
"r"
)
as
indi_rttm_file
:
shutil
.
copyfileobj
(
indi_rttm_file
,
cat_file
)
msg
=
"The system generated RTTM file for %s set : %s"
%
(
split_type
,
concate_rttm_file
,
)
logger
.
debug
(
msg
)
return
concate_rttm_file
def
dev_pval_tuner
(
full_meta
,
save_dir
,
config
):
"""Tuning p_value for affinity matrix.
The p_value used so that only p% of the values in each row is retained.
"""
DER_list
=
[]
prange
=
np
.
arange
(
0.002
,
0.015
,
0.001
)
n_lambdas
=
None
# using it as flag later.
for
p_v
in
prange
:
# Process whole dataset for value of p_v.
concate_rttm_file
=
diarize_dataset
(
full_meta
,
"dev"
,
n_lambdas
,
p_v
,
save_dir
,
config
)
ref_rttm_file
=
os
.
path
.
join
(
save_dir
,
config
.
ref_rttm_dir
,
"fullref_ami_dev.rttm"
)
sys_rttm_file
=
concate_rttm_file
[
MS
,
FA
,
SER
,
DER_
]
=
DER
(
ref_rttm_file
,
sys_rttm_file
,
config
.
ignore_overlap
,
config
.
forgiveness_collar
,
)
DER_list
.
append
(
DER_
)
if
config
.
oracle_n_spkrs
is
True
and
config
.
backend
==
"kmeans"
:
# no need of p_val search. Note p_val is needed for SC for both oracle and est num of speakers.
# p_val is needed in oracle_n_spkr=False when using kmeans backend.
break
# Take p_val that gave minmum DER on Dev dataset.
tuned_p_val
=
prange
[
DER_list
.
index
(
min
(
DER_list
))]
return
tuned_p_val
def
dev_ahc_threshold_tuner
(
full_meta
,
save_dir
,
config
):
"""Tuning threshold for affinity matrix. This function is called when AHC is used as backend.
"""
DER_list
=
[]
prange
=
np
.
arange
(
0.0
,
1.0
,
0.1
)
n_lambdas
=
None
# using it as flag later.
# Note: p_val is threshold in case of AHC.
for
p_v
in
prange
:
# Process whole dataset for value of p_v.
concate_rttm_file
=
diarize_dataset
(
full_meta
,
"dev"
,
n_lambdas
,
p_v
,
save_dir
,
config
)
ref_rttm
=
os
.
path
.
join
(
save_dir
,
config
.
ref_rttm_dir
,
"fullref_ami_dev.rttm"
)
sys_rttm
=
concate_rttm_file
[
MS
,
FA
,
SER
,
DER_
]
=
DER
(
ref_rttm
,
sys_rttm
,
config
.
ignore_overlap
,
config
.
forgiveness_collar
,
)
DER_list
.
append
(
DER_
)
if
config
.
oracle_n_spkrs
is
True
:
break
# no need of threshold search.
# Take p_val that gave minmum DER on Dev dataset.
tuned_p_val
=
prange
[
DER_list
.
index
(
min
(
DER_list
))]
return
tuned_p_val
def
dev_nn_tuner
(
full_meta
,
split_type
,
save_dir
,
config
):
"""Tuning n_neighbors on dev set. Assuming oracle num of speakers.
This is used when nn based affinity is selected.
"""
DER_list
=
[]
pval
=
None
# Now assumming oracle num of speakers.
n_lambdas
=
4
for
nn
in
range
(
5
,
15
):
# Process whole dataset for value of n_lambdas.
concate_rttm_file
=
diarize_dataset
(
full_meta
,
"dev"
,
n_lambdas
,
p_v
,
save_dir
,
config
,
nn
)
ref_rttm
=
os
.
path
.
join
(
save_dir
,
config
.
ref_rttm_dir
,
"fullref_ami_dev.rttm"
)
sys_rttm
=
concate_rttm_file
[
MS
,
FA
,
SER
,
DER_
]
=
DER
(
ref_rttm
,
sys_rttm
,
config
.
ignore_overlap
,
config
.
forgiveness_collar
,
)
DER_list
.
append
([
nn
,
DER_
])
if
config
.
oracle_n_spkrs
is
True
and
config
.
backend
==
"kmeans"
:
break
DER_list
.
sort
(
key
=
lambda
x
:
x
[
1
])
tunned_nn
=
DER_list
[
0
]
return
tunned_nn
[
0
]
def
dev_tuner
(
full_meta
,
split_type
,
save_dir
,
config
):
"""Tuning n_components on dev set. Used for nn based affinity matrix.
Note: This is a very basic tunning for nn based affinity.
This is work in progress till we find a better way.
"""
DER_list
=
[]
pval
=
None
for
n_lambdas
in
range
(
1
,
config
.
max_num_spkrs
+
1
):
# Process whole dataset for value of n_lambdas.
concate_rttm_file
=
diarize_dataset
(
full_meta
,
"dev"
,
n_lambdas
,
p_v
,
save_dir
,
config
)
ref_rttm
=
os
.
path
.
join
(
save_dir
,
config
.
ref_rttm_dir
,
"fullref_ami_dev.rttm"
)
sys_rttm
=
concate_rttm_file
[
MS
,
FA
,
SER
,
DER_
]
=
DER
(
ref_rttm
,
sys_rttm
,
config
.
ignore_overlap
,
config
.
forgiveness_collar
,
)
DER_list
.
append
(
DER_
)
# Take n_lambdas with minmum DER.
tuned_n_lambdas
=
DER_list
.
index
(
min
(
DER_list
))
+
1
return
tuned_n_lambdas
def
main
(
args
,
config
):
# AMI Dev Set: Tune hyperparams on dev set.
# Read the embdding file for dev set generated during embdding compute
dev_meta_file
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
meta_data_dir
,
"ami_dev."
+
config
.
mic_type
+
".subsegs.json"
,
)
with
open
(
dev_meta_file
,
"r"
)
as
f
:
meta_dev
=
json
.
load
(
f
)
full_meta
=
meta_dev
# Processing starts from here
# Following few lines selects option for different backend and affinity matrices. Finds best values for hyperameters using dev set.
ref_rttm_file
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
ref_rttm_dir
,
"fullref_ami_dev.rttm"
)
best_nn
=
None
if
config
.
affinity
==
"nn"
:
logger
.
info
(
"Tuning for nn (Multiple iterations over AMI Dev set)"
)
best_nn
=
dev_nn_tuner
(
full_meta
,
args
.
data_dir
,
config
)
n_lambdas
=
None
best_pval
=
None
if
config
.
affinity
==
"cos"
and
(
config
.
backend
==
"SC"
or
config
.
backend
==
"kmeans"
):
# oracle num_spkrs or not, doesn't matter for kmeans and SC backends
# cos: Tune for the best pval for SC /kmeans (for unknown num of spkrs)
logger
.
info
(
"Tuning for p-value for SC (Multiple iterations over AMI Dev set)"
)
best_pval
=
dev_pval_tuner
(
full_meta
,
args
.
data_dir
,
config
)
elif
config
.
backend
==
"AHC"
:
logger
.
info
(
"Tuning for threshold-value for AHC"
)
best_threshold
=
dev_ahc_threshold_tuner
(
full_meta
,
args
.
data_dir
,
config
)
best_pval
=
best_threshold
else
:
# NN for unknown num of speakers (can be used in future)
if
config
.
oracle_n_spkrs
is
False
:
# nn: Tune num of number of components (to be updated later)
logger
.
info
(
"Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)"
)
# dev_tuner used for tuning num of components in NN. Can be used in future.
n_lambdas
=
dev_tuner
(
full_meta
,
args
.
data_dir
,
config
)
# load 'dev' and 'eval' metadata files.
full_meta_dev
=
full_meta
# current full_meta is for 'dev'
eval_meta_file
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
meta_data_dir
,
"ami_eval."
+
config
.
mic_type
+
".subsegs.json"
,
)
with
open
(
eval_meta_file
,
"r"
)
as
f
:
full_meta_eval
=
json
.
load
(
f
)
# tag to be appended to final output DER files. Writing DER for individual files.
type_of_num_spkr
=
"oracle"
if
config
.
oracle_n_spkrs
else
"est"
tag
=
(
type_of_num_spkr
+
"_"
+
str
(
config
.
affinity
)
+
"."
+
config
.
mic_type
)
# perform final diarization on 'dev' and 'eval' with best hyperparams.
final_DERs
=
{}
out_der_dir
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
der_dir
)
if
not
os
.
path
.
exists
(
out_der_dir
):
os
.
makedirs
(
out_der_dir
)
for
split_type
in
[
"dev"
,
"eval"
]:
if
split_type
==
"dev"
:
full_meta
=
full_meta_dev
else
:
full_meta
=
full_meta_eval
# performing diarization.
msg
=
"Diarizing using best hyperparams: "
+
split_type
+
" set"
logger
.
info
(
msg
)
out_boundaries
=
diarize_dataset
(
full_meta
,
split_type
,
n_lambdas
=
n_lambdas
,
pval
=
best_pval
,
n_neighbors
=
best_nn
,
save_dir
=
args
.
data_dir
,
config
=
config
)
# computing DER.
msg
=
"Computing DERs for "
+
split_type
+
" set"
logger
.
info
(
msg
)
ref_rttm
=
os
.
path
.
join
(
args
.
data_dir
,
config
.
ref_rttm_dir
,
"fullref_ami_"
+
split_type
+
".rttm"
)
sys_rttm
=
out_boundaries
[
MS
,
FA
,
SER
,
DER_vals
]
=
DER
(
ref_rttm
,
sys_rttm
,
config
.
ignore_overlap
,
config
.
forgiveness_collar
,
individual_file_scores
=
True
,
)
# writing DER values to a file. Append tag.
der_file_name
=
split_type
+
"_DER_"
+
tag
out_der_file
=
os
.
path
.
join
(
out_der_dir
,
der_file_name
)
msg
=
"Writing DER file to: "
+
out_der_file
logger
.
info
(
msg
)
diar
.
write_ders_file
(
ref_rttm
,
DER_vals
,
out_der_file
)
msg
=
(
"AMI "
+
split_type
+
" set DER = %s %%
\n
"
%
(
str
(
round
(
DER_vals
[
-
1
],
2
))))
logger
.
info
(
msg
)
final_DERs
[
split_type
]
=
round
(
DER_vals
[
-
1
],
2
)
# final print DERs
msg
=
(
"Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%
\n
"
%
(
str
(
final_DERs
[
"dev"
]),
str
(
final_DERs
[
"eval"
])))
logger
.
info
(
msg
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--config"
,
default
=
None
,
type
=
str
,
help
=
"configuration file"
)
parser
.
add_argument
(
"--data-dir"
,
default
=
"../data/"
,
type
=
str
,
help
=
"processsed data directory"
)
args
=
parser
.
parse_args
()
config
=
CfgNode
(
new_allowed
=
True
)
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
config
.
freeze
()
print
(
config
)
main
(
args
,
config
)
examples/ami/sd0/local/process.sh
0 → 100755
浏览文件 @
378fe590
#!/bin/bash
stage
=
2
set
=
L
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
set
-u
set
-o
pipefail
data_folder
=
$1
manual_annot_folder
=
$2
save_folder
=
$3
pretrained_model_dir
=
$4
conf_path
=
$5
ref_rttm_dir
=
${
save_folder
}
/ref_rttms
meta_data_dir
=
${
save_folder
}
/metadata
if
[
${
stage
}
-le
0
]
;
then
echo
"AMI Data preparation"
python
local
/ami_prepare.py
--data_folder
${
data_folder
}
\
--manual_annot_folder
${
manual_annot_folder
}
\
--save_folder
${
save_folder
}
--ref_rttm_dir
${
ref_rttm_dir
}
\
--meta_data_dir
${
meta_data_dir
}
if
[
$?
-ne
0
]
;
then
echo
"Prepare AMI failed. Please check log message."
exit
1
fi
echo
"AMI data preparation done."
fi
if
[
${
stage
}
-le
1
]
;
then
# extra embddings for dev and eval dataset
for
name
in
dev
eval
;
do
python
local
/compute_embdding.py
--config
${
conf_path
}
\
--data-dir
${
save_folder
}
\
--device
gpu:0
\
--dataset
${
name
}
\
--load-checkpoint
${
pretrained_model_dir
}
done
fi
if
[
${
stage
}
-le
2
]
;
then
# tune hyperparams on dev set
# perform final diarization on 'dev' and 'eval' with best hyperparams
python
local
/experiment.py
--config
${
conf_path
}
\
--data-dir
${
save_folder
}
fi
examples/ami/sd0/run.sh
浏览文件 @
378fe590
#!/bin/bash
#!/bin/bash
.
path.sh
||
exit
1
;
.
./
path.sh
||
exit
1
;
set
-e
set
-e
stage
=
1
stage
=
1
stop_stage
=
50
#TARGET_DIR=${MAIN_ROOT}/dataset/ami
TARGET_DIR
=
/home/dataset/AMI
data_folder
=
${
TARGET_DIR
}
/amicorpus
#e.g., /path/to/amicorpus/
manual_annot_folder
=
${
TARGET_DIR
}
/ami_public_manual_1.6.2
#e.g., /path/to/ami_public_manual_1.6.2/
save_folder
=
./save
pretraind_model_dir
=
${
save_folder
}
/model
conf_path
=
conf/ecapa_tdnn.yaml
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
if
[
${
stage
}
-le
1
]
;
then
if
[
$stage
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# prepare data
# Prepare data and model
bash ./local/data.sh
||
exit
-1
# Download AMI corpus, You need around 10GB of free space to get whole data
# The signals are too large to package in this way,
# so you need to use the chooser to indicate which ones you wish to download
echo
"Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
echo
"Annotations: AMI manual annotations v1.6.2 "
echo
"Signals: "
echo
"1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
echo
"2) Select media streams: Just select Headset mix"
# Download the pretrained Model from HuggingFace or other pretrained model
echo
"Please download the pretrained ECAPA-TDNN Model and put the pretrainde model in given path: "
${
pretraind_model_dir
}
fi
fi
if
[
$stage
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# Tune hyperparams on dev set and perform final diarization on dev and eval with best hyperparams.
bash ./local/process.sh
${
data_folder
}
${
manual_annot_folder
}
${
save_folder
}
${
pretraind_model_dir
}
${
conf_path
}
||
exit
1
fi
paddlespeech/vector/cluster/diarization.py
浏览文件 @
378fe590
...
@@ -746,6 +746,77 @@ def merge_ssegs_same_speaker(lol):
...
@@ -746,6 +746,77 @@ def merge_ssegs_same_speaker(lol):
return
new_lol
return
new_lol
def
write_ders_file
(
ref_rttm
,
DER
,
out_der_file
):
"""Write the final DERs for individual recording.
Arguments
---------
ref_rttm : str
Reference RTTM file.
DER : array
Array containing DER values of each recording.
out_der_file : str
File to write the DERs.
"""
rttm
=
read_rttm
(
ref_rttm
)
spkr_info
=
list
(
filter
(
lambda
x
:
x
.
startswith
(
"SPKR-INFO"
),
rttm
))
rec_id_list
=
[]
count
=
0
with
open
(
out_der_file
,
"w"
)
as
f
:
for
row
in
spkr_info
:
a
=
row
.
split
(
" "
)
rec_id
=
a
[
1
]
if
rec_id
not
in
rec_id_list
:
r
=
[
rec_id
,
str
(
round
(
DER
[
count
],
2
))]
rec_id_list
.
append
(
rec_id
)
line_str
=
" "
.
join
(
r
)
f
.
write
(
"%s
\n
"
%
line_str
)
count
+=
1
r
=
[
"OVERALL "
,
str
(
round
(
DER
[
count
],
2
))]
line_str
=
" "
.
join
(
r
)
f
.
write
(
"%s
\n
"
%
line_str
)
def
get_oracle_num_spkrs
(
rec_id
,
spkr_info
):
"""
Returns actual number of speakers in a recording from the ground-truth.
This can be used when the condition is oracle number of speakers.
Arguments
---------
rec_id : str
Recording ID for which the number of speakers have to be obtained.
spkr_info : list
Header of the RTTM file. Starting with `SPKR-INFO`.
Example
-------
>>> from speechbrain.processing import diarization as diar
>>> spkr_info = ['SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.A <NA> <NA>',
... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.B <NA> <NA>',
... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.C <NA> <NA>',
... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.D <NA> <NA>',
... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.A <NA> <NA>',
... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.B <NA> <NA>',
... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.C <NA> <NA>']
>>> diar.get_oracle_num_spkrs('ES2011a', spkr_info)
4
>>> diar.get_oracle_num_spkrs('ES2011b', spkr_info)
3
"""
num_spkrs
=
0
for
line
in
spkr_info
:
if
rec_id
in
line
:
# Since rec_id is prefix for each speaker
num_spkrs
+=
1
return
num_spkrs
def
distribute_overlap
(
lol
):
def
distribute_overlap
(
lol
):
"""
"""
Distributes the overlapped speech equally among the adjacent segments
Distributes the overlapped speech equally among the adjacent segments
...
@@ -826,6 +897,29 @@ def distribute_overlap(lol):
...
@@ -826,6 +897,29 @@ def distribute_overlap(lol):
return
new_lol
return
new_lol
def
read_rttm
(
rttm_file_path
):
"""
Reads and returns RTTM in list format.
Arguments
---------
rttm_file_path : str
Path to the RTTM file to be read.
Returns
-------
rttm : list
List containing rows of RTTM file.
"""
rttm
=
[]
with
open
(
rttm_file_path
,
"r"
)
as
f
:
for
line
in
f
:
entry
=
line
[:
-
1
]
rttm
.
append
(
entry
)
return
rttm
def
write_rttm
(
segs_list
,
out_rttm_file
):
def
write_rttm
(
segs_list
,
out_rttm_file
):
"""
"""
Writes the segment list in RTTM format (A standard NIST format).
Writes the segment list in RTTM format (A standard NIST format).
...
...
utils/compute_der.py
0 → 100755
浏览文件 @
378fe590
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS),
False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation.
Credits
This code is adapted from https://github.com/speechbrain/speechbrain
"""
import
argparse
import
os
import
re
import
subprocess
import
numpy
as
np
FILE_IDS
=
re
.
compile
(
r
"(?<=Speaker Diarization for).+(?=\*\*\*)"
)
SCORED_SPEAKER_TIME
=
re
.
compile
(
r
"(?<=SCORED SPEAKER TIME =)[\d.]+"
)
MISS_SPEAKER_TIME
=
re
.
compile
(
r
"(?<=MISSED SPEAKER TIME =)[\d.]+"
)
FA_SPEAKER_TIME
=
re
.
compile
(
r
"(?<=FALARM SPEAKER TIME =)[\d.]+"
)
ERROR_SPEAKER_TIME
=
re
.
compile
(
r
"(?<=SPEAKER ERROR TIME =)[\d.]+"
)
def
rectify
(
arr
):
"""Corrects corner cases and converts scores into percentage.
"""
# Numerator and denominator both 0.
arr
[
np
.
isnan
(
arr
)]
=
0
# Numerator > 0, but denominator = 0.
arr
[
np
.
isinf
(
arr
)]
=
1
arr
*=
100.0
return
arr
def
DER
(
ref_rttm
,
sys_rttm
,
ignore_overlap
=
False
,
collar
=
0.25
,
individual_file_scores
=
False
,
):
"""Computes Missed Speaker percentage (MS), False Alarm (FA),
Speaker Error Rate (SER), and Diarization Error Rate (DER).
Arguments
---------
ref_rttm : str
The path of reference/groundtruth RTTM file.
sys_rttm : str
The path of the system generated RTTM file.
individual_file_scores : bool
If True, returns scores for each file in order.
collar : float
Forgiveness collar.
ignore_overlap : bool
If True, ignores overlapping speech during evaluation.
Returns
-------
MS : float array
Missed Speech.
FA : float array
False Alarms.
SER : float array
Speaker Error Rates.
DER : float array
Diarization Error Rates.
"""
curr
=
os
.
path
.
abspath
(
os
.
path
.
dirname
(
__file__
))
mdEval
=
os
.
path
.
join
(
curr
,
"./md-eval.pl"
)
cmd
=
[
mdEval
,
"-af"
,
"-r"
,
ref_rttm
,
"-s"
,
sys_rttm
,
"-c"
,
str
(
collar
),
]
print
(
cmd
)
if
ignore_overlap
:
cmd
.
append
(
"-1"
)
try
:
stdout
=
subprocess
.
check_output
(
cmd
,
stderr
=
subprocess
.
STDOUT
)
except
subprocess
.
CalledProcessError
as
ex
:
stdout
=
ex
.
output
else
:
stdout
=
stdout
.
decode
(
"utf-8"
)
# Get all recording IDs
file_ids
=
[
m
.
strip
()
for
m
in
FILE_IDS
.
findall
(
stdout
)]
file_ids
=
[
file_id
[
2
:]
if
file_id
.
startswith
(
"f="
)
else
file_id
for
file_id
in
file_ids
]
scored_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
SCORED_SPEAKER_TIME
.
findall
(
stdout
)])
miss_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
MISS_SPEAKER_TIME
.
findall
(
stdout
)])
fa_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
FA_SPEAKER_TIME
.
findall
(
stdout
)])
error_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
ERROR_SPEAKER_TIME
.
findall
(
stdout
)])
with
np
.
errstate
(
invalid
=
"ignore"
,
divide
=
"ignore"
):
tot_error_times
=
(
miss_speaker_times
+
fa_speaker_times
+
error_speaker_times
)
miss_speaker_frac
=
miss_speaker_times
/
scored_speaker_times
fa_speaker_frac
=
fa_speaker_times
/
scored_speaker_times
sers_frac
=
error_speaker_times
/
scored_speaker_times
ders_frac
=
tot_error_times
/
scored_speaker_times
# Values in percentage of scored_speaker_time
miss_speaker
=
rectify
(
miss_speaker_frac
)
fa_speaker
=
rectify
(
fa_speaker_frac
)
sers
=
rectify
(
sers_frac
)
ders
=
rectify
(
ders_frac
)
if
individual_file_scores
:
return
miss_speaker
,
fa_speaker
,
sers
,
ders
else
:
return
miss_speaker
[
-
1
],
fa_speaker
[
-
1
],
sers
[
-
1
],
ders
[
-
1
]
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Compute DER"
)
parser
.
add_argument
(
"--ref_rttm"
,
type
=
str
,
help
=
"the path of reference/groundtruth RTTM file"
)
parser
.
add_argument
(
"--sys_rttm"
,
type
=
str
,
help
=
"the path of the system generated RTTM file."
)
parser
.
add_argument
(
"--individual_file_scores"
,
type
=
bool
,
help
=
"whether returns scores for each file in order."
)
parser
.
add_argument
(
"--collar"
,
type
=
float
,
help
=
"forgiveness collar."
)
parser
.
add_argument
(
"--ignore_overlap"
,
type
=
bool
,
help
=
"whether ignores overlapping speech during evaluation."
)
args
=
parser
.
parse_args
()
Scores
=
DER
(
args
.
ref_rttm
,
args
.
sys_rttm
,
args
.
ignore_overlap
,
args
.
collar
,
args
.
individual_file_scores
)
print
(
Scores
)
if
__name__
==
"__main__"
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录