未验证 提交 14214566 编写于 作者: R ranchlai 提交者: GitHub

update example to use new functionals (#5330)

上级 4f1462d7
......@@ -26,7 +26,7 @@ import paddle
import paddleaudio
import yaml
from paddle.io import DataLoader, Dataset, IterableDataset
from paddleaudio import augment
from paddleaudio.utils import augments
from utils import get_labels, get_ytid_clsidx_mapping
......@@ -108,12 +108,14 @@ class H5AudioSet(Dataset):
x = x[:, :target_len]
if self.training and self.augment:
x = augment.random_crop2d(x,
self.config['mel_crop_len'],
tempo_axis=1)
x = augments.random_crop2d(x,
self.config['mel_crop_len'],
tempo_axis=1)
x = spect_permute(x, tempo_axis=1, nblocks=random_choice([0, 2, 3]))
aug_level = random_choice([0.2, 0.1, 0])
x = augment.adaptive_spect_augment(x, tempo_axis=1, level=aug_level)
x = augments.adaptive_spect_augment(x,
tempo_axis=1,
level=aug_level)
return x.T
def __getitem__(self, idx):
......
......@@ -24,10 +24,11 @@ from dataset import get_val_loader
from model import resnet50
from paddle.utils import download
from sklearn.metrics import average_precision_score, roc_auc_score
from utils import compute_dprime,download_assets
from utils import compute_dprime, download_assets
checkpoint_url = 'https://bj.bcebos.com/paddleaudio/paddleaudio/resnet50_weight_averaging_mAP0.416.pdparams'
def evaluate(epoch, val_loader, model, loss_fn):
model.eval()
avg_loss = 0.0
......
......@@ -22,6 +22,7 @@ import paddleaudio as pa
import yaml
from model import resnet50
from paddle.utils import download
from paddleaudio.functional import melspectrogram
from utils import (download_assets, get_label_name_mapping, get_labels,
get_metrics)
......@@ -32,22 +33,22 @@ checkpoint_url = 'https://bj.bcebos.com/paddleaudio/paddleaudio/resnet50_weight_
def load_and_extract_feature(file, c):
s, _ = pa.load(file, sr=c['sample_rate'])
x = pa.features.melspectrogram(s,
sr=c['sample_rate'],
window_size=c['window_size'],
hop_length=c['hop_size'],
n_mels=c['mel_bins'],
fmin=c['fmin'],
fmax=c['fmax'],
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
x = x.T # !!
x = paddle.Tensor(x).unsqueeze((0, 1))
x = melspectrogram(paddle.to_tensor(s),
sr=c['sample_rate'],
win_length=c['window_size'],
n_fft=c['window_size'],
hop_length=c['hop_size'],
n_mels=c['mel_bins'],
f_min=c['fmin'],
f_max=c['fmax'],
window='hann',
center=True,
pad_mode='reflect',
to_db=True,
amin=1e-3,
top_db=None)
x = x.transpose((0, 2, 1))
x = x.unsqueeze((0, ))
return x
......
......@@ -129,7 +129,7 @@ if __name__ == '__main__':
model.train()
model.clear_gradients()
t0 = time.time()
for batch_id, (x,y) in enumerate(train_loader()):
for batch_id, (x, y) in enumerate(train_loader()):
if step < warm_steps:
optimizer.set_lr(lrs[step])
x.stop_gradient = False
......@@ -215,4 +215,4 @@ if __name__ == '__main__':
else:
factor = 0.8
optimizer.set_lr(optimizer.get_lr() * factor)
print('decreased lr to {}'.format(optimizer.get_lr()))
\ No newline at end of file
print('decreased lr to {}'.format(optimizer.get_lr()))
......@@ -4,8 +4,10 @@ import os
import h5py
import numpy as np
import paddle
import paddleaudio as pa
import tqdm
from paddleaudio.functional import melspectrogram
parser = argparse.ArgumentParser(description='wave2mel')
parser.add_argument('--wav_file', type=str, required=False, default='')
......@@ -64,20 +66,23 @@ if len(h5_files) > 0:
s = src_h5[key][:]
s = pa.depth_convert(s, 'float32')
# s = pa.resample(s,32000,args.sample_rate)
x = pa.features.melspectrogram(s,
sr=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
n_mels=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
dst_h5.create_dataset(key, data=x)
x = melspectrogram(paddle.to_tensor(s),
sr=args.sample_rate,
win_length=args.window_size,
n_fft=args.window_size,
hop_length=args.hop_length,
n_mels=args.mel_bins,
f_min=args.fmin,
f_max=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
to_db=True,
amin=1e-3,
top_db=None)
dst_h5.create_dataset(key, data=x[0].numpy())
src_h5.close()
dst_h5.close()
......@@ -91,23 +96,24 @@ if len(wav_files) > 0:
print(f'{len(wav_files)} wav files listed')
for f in tqdm.tqdm(wav_files):
s, _ = pa.load(f, sr=args.sample_rate)
x = pa.melspectrogram(s,
sr=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
n_mels=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
x = melspectrogram(paddle.to_tensor(s),
sr=args.sample_rate,
win_length=args.window_size,
n_fft=args.window_size,
hop_length=args.hop_length,
n_mels=args.mel_bins,
f_min=args.fmin,
f_max=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
to_db=True,
amin=1e-3,
top_db=None)
# figure(figsize=(8,8))
# imshow(x)
# show()
# print(x.shape)
key = f.split('/')[-1][:11]
dst_h5.create_dataset(key, data=x)
dst_h5.create_dataset(key, data=x[0].numpy())
dst_h5.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册