提交 2c9dc0c8 编写于 作者: X xiongxinlei

add some vector cli comments, test=doc

上级 ef1bc5e8
......@@ -68,12 +68,13 @@ class VectorExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog="paddlespeech.vector", add_help=True)
self.parser.add_argument(
"--model",
type=str,
default="ecapatdnn_voxceleb12",
choices=["ecapatdnn_voxceleb12"],
help="Choose model type of asr task.")
help="Choose model type of vector task.")
self.parser.add_argument(
"--task",
type=str,
......@@ -81,7 +82,7 @@ class VectorExecutor(BaseExecutor):
choices=["spk"],
help="task type in vector domain")
self.parser.add_argument(
"--input", type=str, default=None, help="Audio file to recognize.")
"--input", type=str, default=None, help="Audio file to extract embedding.")
self.parser.add_argument(
"--sample_rate",
type=int,
......@@ -186,7 +187,7 @@ class VectorExecutor(BaseExecutor):
sample_rate (int, optional): model sample rate. Defaults to 16000.
config (os.PathLike, optional): yaml config. Defaults to None.
ckpt_path (os.PathLike, optional): pretrained model path. Defaults to None.
device (_type_, optional): paddle running host device. Defaults to paddle.get_device().
device (optional): paddle running host device. Defaults to paddle.get_device().
Returns:
dict: return the audio embedding and the embedding shape
......@@ -216,6 +217,7 @@ class VectorExecutor(BaseExecutor):
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""get the neural network path from the pretrained model list
we store all the pretrained models in the variable `pretrained_models`
Args:
tag (str): model tag in the pretrained model list
......@@ -332,6 +334,7 @@ class VectorExecutor(BaseExecutor):
logger.info(f"embedding size: {embedding.shape}")
# stage 2: put the embedding and dim info to _outputs property
# the embedding type is numpy.array
self._outputs["embedding"] = embedding
def postprocess(self) -> Union[str, os.PathLike]:
......@@ -356,6 +359,7 @@ class VectorExecutor(BaseExecutor):
logger.info(f"Preprocess audio file: {audio_file}")
# stage 1: load the audio sample points
# Note: this process must match the training process
waveform, sr = load_audio(audio_file)
logger.info(f"load the audio sample points, shape is: {waveform.shape}")
......@@ -397,7 +401,7 @@ class VectorExecutor(BaseExecutor):
sample_rate (int): the desired model sample rate
Returns:
bool: return if the audio sample rate matches the model sample rate
bool: return if the audio sample rate matches the model sample rate
"""
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册