提交 d85d1dee 编写于 作者: X xiongxinlei

exec pre-commit in paddlespeech vector, test=doc

上级 9874fb7d
......@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
# limitations under the License.
......@@ -13,9 +13,8 @@
# limitations under the License.
import argparse
import os
import time
import numpy as np
import paddle
from yacs.config import CfgNode
......@@ -40,7 +39,8 @@ def extract_audio_embedding(args, config):
ecapa_tdnn = EcapaTdnn(**config.model)
# stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=config.num_speakers)
model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=config.num_speakers)
# stage 2: load the pre-trained model
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
......@@ -62,17 +62,17 @@ def extract_audio_embedding(args, config):
# we need convert the audio feat to one-batch shape [batch, dim, time], where the batch is one
# so the final shape is [1, dim, time]
start_time = time.time()
feat = melspectrogram(x=waveform,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
feat = melspectrogram(
x=waveform,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
feat = paddle.to_tensor(feat).unsqueeze(0)
# in inference period, the lengths is all one without padding
lengths = paddle.ones([1])
feat = feature_normalize(
feat, mean_norm=True, std_norm=False)
feat = feature_normalize(feat, mean_norm=True, std_norm=False)
# model backbone network forward the feats and get the embedding
embedding = model.backbone(
......@@ -80,7 +80,6 @@ def extract_audio_embedding(args, config):
elapsed_time = time.time() - start_time
audio_length = waveform.shape[0] / sr
# stage 5: do global norm with external mean and std
rtf = elapsed_time / audio_length
logger.info(f"{args.device} rft={rtf}")
......
......@@ -13,9 +13,9 @@
# limitations under the License.
import argparse
import os
import time
import numpy as np
import time
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
......@@ -27,6 +27,7 @@ from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import batch_pad_right
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
......@@ -36,7 +37,6 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer
from paddlespeech.vector.io.batch import batch_pad_right
logger = Log(__name__).getlog()
......@@ -145,7 +145,7 @@ def main(args, config):
reader_start = time.time()
for batch_idx, batch in enumerate(train_loader):
train_reader_cost += time.time() - reader_start
# stage 9-1: batch data is audio sample points and speaker id label
feat_start = time.time()
waveforms, labels = batch['waveforms'], batch['labels']
......@@ -165,11 +165,12 @@ def main(args, config):
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
feat = melspectrogram(
x=waveform,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
......@@ -202,7 +203,7 @@ def main(args, config):
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
timer.count() # step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs
if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
lr = optimizer.get_lr()
......@@ -213,9 +214,12 @@ def main(args, config):
epoch, config.epochs, batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' avg_reader_cost: {:.5f} sec,'.format(train_reader_cost / config.log_interval)
print_msg += ' avg_feat_cost: {:.5f} sec,'.format(train_feat_cost / config.log_interval)
print_msg += ' avg_train_cost: {:.5f} sec,'.format(train_run_cost / config.log_interval)
print_msg += ' avg_reader_cost: {:.5f} sec,'.format(
train_reader_cost / config.log_interval)
print_msg += ' avg_feat_cost: {:.5f} sec,'.format(
train_feat_cost / config.log_interval)
print_msg += ' avg_train_cost: {:.5f} sec,'.format(
train_run_cost / config.log_interval)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
logger.info(print_msg)
......@@ -262,11 +266,12 @@ def main(args, config):
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
feat = melspectrogram(
x=waveform,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
......@@ -285,7 +290,8 @@ def main(args, config):
# stage 9-14: Save model parameters
save_dir = os.path.join(args.checkpoint_dir,
'epoch_{}'.format(epoch))
last_saved_epoch = os.path.join('epoch_{}'.format(epoch), "model.pdparams")
last_saved_epoch = os.path.join('epoch_{}'.format(epoch),
"model.pdparams")
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))
......@@ -300,10 +306,13 @@ def main(args, config):
final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
logger.info(f"we will create the final model: {final_model}")
if os.path.islink(final_model):
logger.info(f"An {final_model} already exists, we will rm is and create it again")
logger.info(
f"An {final_model} already exists, we will rm is and create it again"
)
os.unlink(final_model)
os.symlink(last_saved_epoch, final_model)
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册