提交 48b8cc89 编写于 作者: X xiongxinlei

add score method, test=doc

上级 cfc390e0
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Dict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
...@@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor): ...@@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor):
"--task", "--task",
type=str, type=str,
default="spk", default="spk",
choices=["spk"], choices=["spk", "score"],
help="task type in vector domain") help="task type in vector domain")
self.parser.add_argument( self.parser.add_argument(
"--input", "--input",
...@@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor): ...@@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor):
logger.info(f"task source: {task_source}") logger.info(f"task source: {task_source}")
# stage 3: process the audio one by one # stage 3: process the audio one by one
# choose the action to perform according to the task type
task_result = OrderedDict() task_result = OrderedDict()
has_exceptions = False has_exceptions = False
for id_, input_ in task_source.items(): for id_, input_ in task_source.items():
try: try:
res = self(input_, model, sample_rate, config, ckpt_path, # extract the speaker audio embedding
device) if parser_args.task == "spk":
task_result[id_] = res logger.info("do vector spk task")
res = self(input_, model, sample_rate, config, ckpt_path,
device)
task_result[id_] = res
elif parser_args.task == "score":
logger.info("do vector score task")
logger.info(f"input content {input_}")
if len(input_.split()) != 2:
logger.error(
f"vector score task input {input_} wav num is not two,"
"that is {len(input_.split())}")
sys.exit(-1)
# get the enroll and test embedding
enroll_audio, test_audio = input_.split()
logger.info(
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
)
enroll_embedding = self(enroll_audio, model, sample_rate,
config, ckpt_path, device)
test_embedding = self(test_audio, model, sample_rate,
config, ckpt_path, device)
# get the score
res = self.get_embeddings_score(enroll_embedding,
test_embedding)
task_result[id_] = res
except Exception as e: except Exception as e:
has_exceptions = True has_exceptions = True
task_result[id_] = f'{e.__class__.__name__}: {e}' task_result[id_] = f'{e.__class__.__name__}: {e}'
...@@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor): ...@@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor):
else: else:
return True return True
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
Refactor from the Executor._get_job_contents
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
if not line:
continue
k = line.split(' ')[0]
v = ' '.join(line.split(' ')[1:])
job_contents[k] = v
return job_contents
def get_embeddings_score(self, enroll_embedding, test_embedding):
    """Score an enroll embedding against a test embedding.

    The scoring function is a ``paddle.nn.CosineSimilarity(axis=0)`` that
    is created on first use and stored on the instance as
    ``self.score_func``, so later calls reuse it.

    Args:
        enroll_embedding (numpy.array): enroll audio embedding; the
            original documentation gives its shape as (emb_size).
        test_embedding (numpy.array): test audio embedding; the original
            documentation gives its shape as (emb_size).

    Returns:
        score: the value returned by ``.item()`` on the similarity tensor
            (presumably a Python float -- confirm against paddle docs).
    """
    try:
        scorer = self.score_func
    except AttributeError:
        # First call on this instance: build and cache the scorer.
        self.score_func = paddle.nn.CosineSimilarity(axis=0)
        logger.info("create the cosine score function ")
        scorer = self.score_func

    enroll_tensor = paddle.to_tensor(enroll_embedding)
    test_tensor = paddle.to_tensor(test_embedding)
    similarity = scorer(enroll_tensor, test_tensor)
    return similarity.item()
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
audio_file: os.PathLike, audio_file: os.PathLike,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册