未验证 提交 7260766f 编写于 作者: R ranchlai 提交者: GitHub

PaddleAudio framework alpha version (#5311)

* added sound classication

* added liscense, clean code, add pre-commit

* update req

* moved to PaddlePaddle-models

* code re-structure

* update README.md

* update README.md

* Update README.md

* add audioset training

* default resample mode to kaiser_fast

* delete some comments

* precommit check

* sha->rev

* add config.ymal

* remove SoundClassification from paddlespeech, since it's in PaddleAudio now

* add labels

* remove old labels

* update code

* empty

* #5300

* add evaluate, etc

* remove trace|

* import evaluate

* path update

* precommit check

* recover slowfast

* restore README.md to paddle:develop

* refactor

* update readme

* update README.md

* refactor

* refactor

* refactor

* refactor

* precommit fixed

* update README.md

* Update README.md

* Update README.md

* Update train.py

changed prefixed, removed some comments

* add wav file for testing

* bug fixed eval,new checkpoint map=0.416

* Update README.md
Co-authored-by: Nranchlai <=ranchlai@163.com>
上级 47110429
repos:
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
......@@ -15,7 +16,7 @@
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
rev: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
......
......@@ -4,3 +4,4 @@ nohup.out
__pycache__/
*.wav
*.m4a
obsolete/**
repos:
- repo: local
hooks:
- id: yapf
......@@ -8,7 +9,7 @@
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
......
[style]
based_on_style = pep8
column_limit = 120
column_limit = 80
# PaddleAudio
Unofficial paddle audio codebase
# PaddleAudio: The audio library for PaddlePaddle
## Introduction
PaddleAudio is the audio toolkit to speed up your audio research and development loop in PaddlePaddle. It currently provides a collection of audio datasets, feature-extraction functions, audio transforms,state-of-the-art pre-trained models in sound tagging/classification and anomaly sound detection. More models and features are on the roadmap.
## Features
- Spectrogram and related features are compatible with librosa.
- State-of-the-art models in sound tagging on Audioset, sound classification on esc50, and more to come.
- Ready-to-use audio embedding with a line of code, includes sound embedding and more on the roadmap.
- Data loading supports for common open source audio in multiple languages including English, Mandarin and so on.
## Install
```
git clone https://github.com/ranchlai/PaddleAudio.git
cd PaddleAudio
git clone https://github.com/PaddlePaddle/models
cd models/PaddleAudio
pip install .
```
## Usage
## Quick start
### Audio loading and feature extraction
```
import paddleAudio as pa
import paddleaudio as pa
s,r = pa.load(f)
mel = pa.features.mel_spect(s,r)
mel_spect = pa.melspectrogram(s,sr=r)
```
## to do
- add sound effects(tempo, mag, etc) , sox supports
- add dataset support
- add models DCASE classication ASD,sound classification
- add demos (audio,video demos)
- add openL3 support
### Examples
We provide a set of examples to help you get started in using PaddleAudio quickly.
- [PANNs: acoustic scene and events analysis using pre-trained models](./examples/panns)
- [Environmental Sound classification on ESC-50 dataset](./examples/sound_classification)
- [Training a audio-tagging network on Audioset](./examples/audioset_training)
Please refer to [example directory](./examples) for more details.
# Training Audio-tagging networks using PaddleAudio
In this example, we showcase how to train a typical CNN using PaddleAudio.
Different from [PANNS](https://github.com/qiuqiangkong/audioset_tagging_cnn)\[1\], which used a customized CNN14, here we are using [resnet50](https://arxiv.org/abs/1512.03385v1)\[2\], a kind of network more commonly-used in deep learning community and has much fewer parameters than the networks used in PANNS. We achieved similar [Audioset](https://research.google.com/audioset/) mAP in our version of balanced evaluation dataset.
## Introduction
Audio-tagging is the task to generate tags for input audio file. It is usually tackled as a multi-label classification problem. Different from commonly-seen multi-class classification which is a single-label problem of categorizing instances into precisely one of more than two classes\[3\], multi-label classification task categorizes instances into one or more classes.
To solve the audio-tagging we can borrow ideas from image-recognition community. For example, we can use residual networks trained on imagenet dataset but change the classification head from 1000-classes to 527-classes for Audioset. We can also use BCE-loss to do multi-label classification, avoiding the label-competition brought by softmax activation and cross-entropy loss (since we allow multiple labels or tags for the same audio instance). However, as mentioned by previous work\[1,4,5\], audioset is weakly labelled(No time information for the exact location of the labels) and highly imbalanced(The training dataset is largely dominated by speech and music), the training is highly prone to over-fitting and we have to change the network and training strategy accordingly as follows.
- Use of extra front-end feature extraction, i.e., convert the audio from waveform to mel-spectrogram. In image classification task, no extra feature extraction is necessary(e.g., not need to convert image to frequency domain)
- Use of weight averaging to reduce the variance and improve generalization. This is necessary to reduce over-fitting. (In our example, weight-averaging improved mAP from 0.375 to 0.401)
- Use resnet50 with 1-channel input and add 4 dropout layers each after the 4 convolution blocks in residual networks. This is motivated by PANNS.
- Use of mixup training. we set mixup\[4\] gamma to 0.5 and verify that mixup training is quite useful for this scenario.
- Use the pretrained weight from imagenet classification task . This approach is sometimes used in audio classification task. In our example, it accelerates the training process at the very first epoch.
- Use of learning-rate warmup heuristic. Learning-rate warmup is commonly used in training CNNs and transformers. We found it stabilize the training and improve final mAP.
- Use balanced sampling to make sure each class of the Audioset is treated evenly.
- Use random cropping, spectrogram permutation, time-freq masking to do training augmentation.
With the above strategies, we have achieved results that on par with the sota, and our network size is much smaller in terms of FLOPs of number of parameters(see the following table). When saved to disk, the weight of this work is only 166 MiB in size, in contrast to 466MiB for CNN14.
| Model | Flops | Params |
| :------------- | :----------: | -----------: |
| CNN14* | 1,658,961,018 | 80,769,871 |
| Resnet50(This example) | 1,327,513,088 | 28,831,567|
(* Note: we use ```paddle.flops()``` to calculate flops and parameters, which gives slightly different results from the origin paper)
## Features
The feature we use in this example is mel-spectrogram, similar to that of [PANNS](https://github.com/qiuqiangkong/audioset_tagging_cnn). The details of the feature parameters are listed in [config.yaml](config.yaml) and also described below:
```
sample_rate: 32000
window_size: 1024
hop_size: 640
mel_bins: 128
fmin: 50
fmax: 16000
```
We provide a script file [wav2mel](./wav2mel.py) to preprocess audio files into mel-spectrograms and stored them as multiple h5 files. You can use it to do preprocessing if you have already downloaded the Audioset.
## Network
Since the input audio feature is a one-channel spectrogram,
we modify the resnet50 to accept 1-channel inputs by setting conv1 as
```
self.conv1 = nn.Conv2D(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias_attr=False)
```
As mentioned, we add dropout layers each after the convolution block in resnet50.
We also added a dropout layer before the last fully-connect layer and set the dropout rate to ```0.2```
The network is defined in [model.py](./model.py)
## Training
Everything about data, feature, training controls, etc, can be found in the [config file](./assets/config.yaml).
In this section we will describe training data and steps on how to run the training.
### Training data
The dataset used in both training and evaluation is [Audioset](https://research.google.com/audioset/). We manually download the video files from Youtube according to the youtube-id listed in the dataset, and convert the audio to wav format of 1-channel and 32K sample rate. We then extract the melspetrogram features as described above and store the features as numpy array into separated h5 file. Each h5 file contains features extracted from 10,000 audio files.
For this experience we have successfully downloaded ```1714174``` valid files for unbalance segment, ```16906``` for balanced training segment, and ```17713``` for balanced evaluation segment. The data statistics are summarized in the following table:
| | unbalanced | balanced train |Evaluation |
| :------------- | :----------: | :-----------: |-----------: |
| [Original](https://research.google.com/audioset/download.html) | 2,042,985 | 22,176 | 20,383 |
| [PANNS](https://arxiv.org/pdf/1912.10211.pdf) | 1,913,637 | 20,550 |18,887 |
| This example | 1,714,174 | 16,906 |17,713 |
Our version of dataset contains fewer audio than those of PANNs due to the reasons that video will gradually become private or simply deleted by the authors. We use all of the audio files from balanced segment and unbalanced segment for training and the rest evaluation segment for testing. This gives up 1,714,174 training files (unevenly) distributed across 527 labels. The label information can be found int [this location](https://research.google.com/audioset/ontology/index.html) and the paper\[7\]
### Run the training
Set all necessary path and training configurations in the file [config.yaml](./config.yaml), then run
```
python train.py --device <device_number>
```
for single gpu training. It takes about 3 hours for training one epoch with balance-sampling strategy.
To restore from a checkpoint, run
```
python train.py --device <device_number> --restore <epoch_num>
```
For multi-gpu training, run
```
python -m paddle.distributed.launch --selected_gpus='0,1,2,3' ./train.py --distributed=1
```
### Training loss function
We use mixup loss described in the paper \[4\]. It's better than simply using binary cross entropy loss in the
multi-label Audioset tagging problem.
### Training log
We use [VisualDL](https://github.com/PaddlePaddle/VisualDL.git) to record training loss and evaluation metrics.
![train_loss.png](./assets/train_loss.png)
## Evaluation
We evaluate audio tagging performance using the same metrics as described in PANNS, namely mAP, AUC,d-prime.
Since our version of evaluation dataset is different from PANNs, we re-evaluate the performance of PANNS using their code and pre-trained weights. For the TAL Net\[5\] and DeepRes\[6\], we directly use the results in the original paper.
To get the statistics of our pre-trained model in this example, run
```
python evaluation.py
```
| Model |mAP |AUC |d-prime|
| :------------- | :----------: |:-----------: |-----------: |
| TAL Net \[5\]* | 0.362| 0.965 |2.56|
| DeepRes \[6\]* | 0.392 | 0.971|2.68|
| PANNS \[1\] | 0.420 ** | 0.970|2.66|
| This example | 0.416 | 0.968 |2.62|
(* indicate different evaluation set than ours, ** stats are different from the paper as we re-evaluated on our version of dataset)
## Inference
You can do inference by passing an input audio file to [inference.py](./inference.py)
```
python inference.py --wav_file <path-to-your-wav-file> --top_k 5
```
which will give you a result like this:
```
labels prob
------------
Speech: 0.744
Cat: 0.721
Meow: 0.681
Domestic animal: 0.627
Animal: 0.488
```
## Reference
- \[1\] Kong, Qiuqiang, et al. “PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition.” IEEE Transactions on Audio, Speech, and Language Processing, vol. 28, 2020, pp. 2880–2894.
- \[2\] He, Kaiming, et al. “Deep Residual Learning for Image Recognition.” 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770–778.
- \[3\] https://en.wikipedia.org/wiki/Multi-label_classification
- \[4\] Zhang, Hongyi, et al. “Mixup: Beyond Empirical Risk Minimization.” International Conference on Learning Representations, 2017.
- \[5\] Kong, Qiuqiang, et al. “Audio Set Classification with Attention Model: A Probabilistic Perspective.” 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. 316–320.
- \[6\] Ford, Logan, et al. “A Deep Residual Network for Large-Scale Acoustic Scene Analysis.” Interspeech 2019, 2019, pp. 2568–2572.
- \[7]\ Gemmeke, Jort F., et al. “Audio Set: An Ontology and Human-Labeled Dataset for Audio Events.” 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2017, pp. 776–780.
/m/09x0r
/m/05zppz
/m/02zsn
/m/0ytgt
/m/01h8n0
/m/02qldy
/m/0261r1
/m/0brhx
/m/07p6fty
/m/07q4ntr
/m/07rwj3x
/m/07sr1lc
/m/04gy_2
/t/dd00135
/m/03qc9zr
/m/02rtxlg
/m/01j3sz
/t/dd00001
/m/07r660_
/m/07s04w4
/m/07sq110
/m/07rgt08
/m/0463cq4
/t/dd00002
/m/07qz6j3
/m/07qw_06
/m/07plz5l
/m/015lz1
/m/0l14jd
/m/01swy6
/m/02bk07
/m/01c194
/t/dd00003
/t/dd00004
/t/dd00005
/t/dd00006
/m/06bxc
/m/02fxyj
/m/07s2xch
/m/07r4k75
/m/01w250
/m/0lyf6
/m/07mzm6
/m/01d3sd
/m/07s0dtb
/m/07pyy8b
/m/07q0yl5
/m/01b_21
/m/0dl9sf8
/m/01hsr_
/m/07ppn3j
/m/06h7j
/m/07qv_x_
/m/07pbtc8
/m/03cczk
/m/07pdhp0
/m/0939n_
/m/01g90h
/m/03q5_w
/m/02p3nc
/m/02_nn
/m/0k65p
/m/025_jnm
/m/0l15bq
/m/01jg02
/m/01jg1z
/m/053hz1
/m/028ght
/m/07rkbfh
/m/03qtwd
/m/07qfr4h
/t/dd00013
/m/0jbk
/m/068hy
/m/0bt9lr
/m/05tny_
/m/07r_k2n
/m/07qf0zm
/m/07rc7d9
/m/0ghcn6
/t/dd00136
/m/01yrx
/m/02yds9
/m/07qrkrw
/m/07rjwbb
/m/07r81j2
/m/0ch8v
/m/03k3r
/m/07rv9rh
/m/07q5rw0
/m/01xq0k1
/m/07rpkh9
/m/0239kh
/m/068zj
/t/dd00018
/m/03fwl
/m/07q0h5t
/m/07bgp
/m/025rv6n
/m/09b5t
/m/07st89h
/m/07qn5dc
/m/01rd7k
/m/07svc2k
/m/09ddx
/m/07qdb04
/m/0dbvp
/m/07qwf61
/m/01280g
/m/0cdnk
/m/04cvmfc
/m/015p6
/m/020bb7
/m/07pggtn
/m/07sx8x_
/m/0h0rv
/m/07r_25d
/m/04s8yn
/m/07r5c2p
/m/09d5_
/m/07r_80w
/m/05_wcq
/m/01z5f
/m/06hps
/m/04rmv
/m/07r4gkf
/m/03vt0
/m/09xqv
/m/09f96
/m/0h2mp
/m/07pjwq1
/m/01h3n
/m/09ld4
/m/07st88b
/m/078jl
/m/07qn4z3
/m/032n05
/m/04rlf
/m/04szw
/m/0fx80y
/m/0342h
/m/02sgy
/m/018vs
/m/042v_gx
/m/06w87
/m/01glhc
/m/07s0s5r
/m/018j2
/m/0jtg0
/m/04rzd
/m/01bns_
/m/07xzm
/m/05148p4
/m/05r5c
/m/01s0ps
/m/013y1f
/m/03xq_f
/m/03gvt
/m/0l14qv
/m/01v1d8
/m/03q5t
/m/0l14md
/m/02hnl
/m/0cfdd
/m/026t6
/m/06rvn
/m/03t3fj
/m/02k_mr
/m/0bm02
/m/011k_j
/m/01p970
/m/01qbl
/m/03qtq
/m/01sm1g
/m/07brj
/m/05r5wn
/m/0xzly
/m/0mbct
/m/016622
/m/0j45pbj
/m/0dwsp
/m/0dwtp
/m/0dwt5
/m/0l156b
/m/05pd6
/m/01kcd
/m/0319l
/m/07gql
/m/07c6l
/m/0l14_3
/m/02qmj0d
/m/07y_7
/m/0d8_n
/m/01xqw
/m/02fsn
/m/085jw
/m/0l14j_
/m/06ncr
/m/01wy6
/m/03m5k
/m/0395lw
/m/03w41f
/m/027m70_
/m/0gy1t2s
/m/07n_g
/m/0f8s22
/m/026fgl
/m/0150b9
/m/03qjg
/m/0mkg
/m/0192l
/m/02bxd
/m/0l14l2
/m/07kc_
/m/0l14t7
/m/01hgjl
/m/064t9
/m/0glt670
/m/02cz_7
/m/06by7
/m/03lty
/m/05r6t
/m/0dls3
/m/0dl5d
/m/07sbbz2
/m/05w3f
/m/06j6l
/m/0gywn
/m/06cqb
/m/01lyv
/m/015y_n
/m/0gg8l
/m/02x8m
/m/02w4v
/m/06j64v
/m/03_d0
/m/026z9
/m/0ggq0m
/m/05lls
/m/02lkt
/m/03mb9
/m/07gxw
/m/07s72n
/m/0283d
/m/0m0jc
/m/08cyft
/m/0fd3y
/m/07lnk
/m/0g293
/m/0ln16
/m/0326g
/m/0155w
/m/05fw6t
/m/02v2lh
/m/0y4f8
/m/0z9c
/m/0164x2
/m/0145m
/m/02mscn
/m/016cjb
/m/028sqc
/m/015vgc
/m/0dq0md
/m/06rqw
/m/02p0sh1
/m/05rwpb
/m/074ft
/m/025td0t
/m/02cjck
/m/03r5q_
/m/0l14gg
/m/07pkxdp
/m/01z7dr
/m/0140xf
/m/0ggx5q
/m/04wptg
/t/dd00031
/t/dd00032
/t/dd00033
/t/dd00034
/t/dd00035
/t/dd00036
/t/dd00037
/m/03m9d0z
/m/09t49
/t/dd00092
/m/0jb2l
/m/0ngt1
/m/0838f
/m/06mb1
/m/07r10fb
/t/dd00038
/m/0j6m2
/m/0j2kx
/m/05kq4
/m/034srq
/m/06wzb
/m/07swgks
/m/02_41
/m/07pzfmf
/m/07yv9
/m/019jd
/m/0hsrw
/m/056ks2
/m/02rlv9
/m/06q74
/m/012f08
/m/0k4j
/m/0912c9
/m/07qv_d5
/m/02mfyn
/m/04gxbd
/m/07rknqz
/m/0h9mv
/t/dd00134
/m/0ltv
/m/07r04
/m/0gvgw0
/m/05x_td
/m/02rhddq
/m/03cl9h
/m/01bjv
/m/03j1ly
/m/04qvtq
/m/012n7d
/m/012ndj
/m/04_sv
/m/0btp2
/m/06d_3
/m/07jdr
/m/04zmvq
/m/0284vy3
/m/01g50p
/t/dd00048
/m/0195fx
/m/0k5j
/m/014yck
/m/04229
/m/02l6bg
/m/09ct_
/m/0cmf2
/m/0199g
/m/06_fw
/m/02mk9
/t/dd00065
/m/08j51y
/m/01yg9g
/m/01j4z9
/t/dd00066
/t/dd00067
/m/01h82_
/t/dd00130
/m/07pb8fc
/m/07q2z82
/m/02dgv
/m/03wwcy
/m/07r67yg
/m/02y_763
/m/07rjzl8
/m/07r4wb8
/m/07qcpgn
/m/07q6cd_
/m/0642b4
/m/0fqfqc
/m/04brg2
/m/023pjk
/m/07pn_8q
/m/0dxrf
/m/0fx9l
/m/02pjr4
/m/02jz0l
/m/0130jx
/m/03dnzn
/m/03wvsk
/m/01jt3m
/m/012xff
/m/04fgwm
/m/0d31p
/m/01s0vc
/m/03v3yw
/m/0242l
/m/01lsmm
/m/02g901
/m/05rj2
/m/0316dw
/m/0c2wf
/m/01m2v
/m/081rb
/m/07pp_mv
/m/07cx4
/m/07pp8cl
/m/01hnzm
/m/02c8p
/m/015jpf
/m/01z47d
/m/046dlr
/m/03kmc9
/m/0dgbq
/m/030rvx
/m/01y3hg
/m/0c3f7m
/m/04fq5q
/m/0l156k
/m/06hck5
/t/dd00077
/m/02bm9n
/m/01x3z
/m/07qjznt
/m/07qjznl
/m/0l7xg
/m/05zc1
/m/0llzx
/m/02x984l
/m/025wky1
/m/024dl
/m/01m4t
/m/0dv5r
/m/07bjf
/m/07k1x
/m/03l9g
/m/03p19w
/m/01b82r
/m/02p01q
/m/023vsd
/m/0_ksk
/m/01d380
/m/014zdl
/m/032s66
/m/04zjc
/m/02z32qm
/m/0_1c
/m/073cg4
/m/0g6b5
/g/122z_qxw
/m/07qsvvw
/m/07pxg6y
/m/07qqyl4
/m/083vt
/m/07pczhz
/m/07pl1bw
/m/07qs1cx
/m/039jq
/m/07q7njn
/m/07rn7sz
/m/04k94
/m/07rrlb6
/m/07p6mqd
/m/07qlwh6
/m/07r5v4s
/m/07prgkl
/m/07pqc89
/t/dd00088
/m/07p7b8y
/m/07qlf79
/m/07ptzwd
/m/07ptfmf
/m/0dv3j
/m/0790c
/m/0dl83
/m/07rqsjt
/m/07qnq_y
/m/07rrh0c
/m/0b_fwt
/m/02rr_
/m/07m2kt
/m/018w8
/m/07pws3f
/m/07ryjzk
/m/07rdhzs
/m/07pjjrj
/m/07pc8lb
/m/07pqn27
/m/07rbp7_
/m/07pyf11
/m/07qb_dv
/m/07qv4k0
/m/07pdjhy
/m/07s8j8t
/m/07plct2
/t/dd00112
/m/07qcx4z
/m/02fs_r
/m/07qwdck
/m/07phxs1
/m/07rv4dm
/m/07s02z0
/m/07qh7jl
/m/07qwyj0
/m/07s34ls
/m/07qmpdm
/m/07p9k1k
/m/07qc9xj
/m/07rwm0c
/m/07phhsh
/m/07qyrcz
/m/07qfgpx
/m/07rcgpl
/m/07p78v5
/t/dd00121
/m/07s12q4
/m/028v0c
/m/01v_m0
/m/0b9m1
/m/0hdsk
/m/0c1dj
/m/07pt_g0
/t/dd00125
/t/dd00126
/t/dd00127
/t/dd00128
/t/dd00129
/m/01b9nn
/m/01jnbd
/m/096m7z
/m/06_y0by
/m/07rgkc5
/m/06xkwv
/m/0g12c5
/m/08p9q4
/m/07szfh9
/m/0chx_
/m/0cj0r
/m/07p_0gm
/m/01jwx6
/m/07c52
/m/06bz3
/m/07hvw1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
sample_rate: 32000
window_size: 1024
hop_size: 640
mel_bins: 128
fmin: 50
fmax: 16000
batch_size : 64
val_batch_size: 8
num_classes: 527
mixup : True
mixup_alpha : 1.0
max_mel_len : 501
mel_crop_len : 480 # used for augmentation
balanced_sampling : True
epoch_num : 500
# for training from scratch
start_lr : 0.0003
warm_steps: 1000
# for fine-tuning
#start_lr : 0.00001
#add dropout in resnet precedding the fc layer
dropout: 0.20
# set the data path accordingly
balanced_train_h5 : './audioset/balanced_train.h5'
unbalanced_train_h5 : './audioset/unbalanced/*.h5'
balanced_eval_h5 : './audioset/balanced_eval.h5'
audioset_label: './assets/audioset_labels.txt'
model_dir : './checkpoints/'
log_path : './log'
checkpoint_step : 5000
lr_dec_per_step : 60000
num_workers : 0
max_time_mask: 5
max_freq_mask: 5
max_time_mask_width: 60
max_freq_mask_width: 60
model_type : 'resnet50' # resnet18,resnet50,resnet101
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import os
import subprocess
import time
import warnings
import h5py
import librosa
import numpy as np
import paddle
import paddleaudio
import yaml
from paddle.io import DataLoader, Dataset, IterableDataset
from paddleaudio import augment
from utils import get_labels, get_ytid_clsidx_mapping
def spect_permute(spect, tempo_axis, nblocks):
"""spectrogram permutaion"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
nt, nf = spect.shape
else:
nf, nt = spect.shape
if nblocks <= 1:
return spect
block_width = nt // nblocks + 1
if tempo_axis == 1:
blocks = [
spect[:, block_width * i:(i + 1) * block_width]
for i in range(nblocks)
]
np.random.shuffle(blocks)
new_spect = np.concatenate(blocks, 1)
else:
blocks = [
spect[block_width * i:(i + 1) * block_width, :]
for i in range(nblocks)
]
np.random.shuffle(blocks)
new_spect = np.concatenate(blocks, 0)
return new_spect
def random_choice(a):
i = np.random.randint(0, high=len(a))
return a[int(i)]
def get_keys(file_pointers):
all_keys = []
key2file = {}
for fp in file_pointers:
all_keys += list(fp.keys())
key2file.update({k: fp for k in fp.keys()})
return all_keys, key2file
class H5AudioSet(Dataset):
"""
Dataset class for Audioset, with mel features stored in multiple hdf5 files.
The h5 files store mel-spectrogram features pre-extracted from wav files.
Use wav2mel.py to do feature extraction.
"""
def __init__(self,
h5_files,
config,
augment=True,
training=True,
balanced_sampling=True):
super(H5AudioSet, self).__init__()
self.h5_files = h5_files
self.config = config
self.file_pointers = [h5py.File(f) for f in h5_files]
self.all_keys, self.key2file = get_keys(self.file_pointers)
self.augment = augment
self.training = training
self.balanced_sampling = balanced_sampling
print(
f'{len(self.h5_files)} h5 files, totally {len(self.all_keys)} audio files listed'
)
self.ytid2clsidx, self.clsidx2ytid = get_ytid_clsidx_mapping()
def _process(self, x):
assert x.shape[0] == self.config[
'mel_bins'], 'the first dimension must be mel frequency'
target_len = self.config['max_mel_len']
if x.shape[1] <= target_len:
pad_width = (target_len - x.shape[1]) // 2 + 1
x = np.pad(x, ((0, 0), (pad_width, pad_width)))
x = x[:, :target_len]
if self.training and self.augment:
x = augment.random_crop2d(x,
self.config['mel_crop_len'],
tempo_axis=1)
x = spect_permute(x, tempo_axis=1, nblocks=random_choice([0, 2, 3]))
aug_level = random_choice([0.2, 0.1, 0])
x = augment.adaptive_spect_augment(x, tempo_axis=1, level=aug_level)
return x.T
def __getitem__(self, idx):
if self.balanced_sampling:
cls_id = int(np.random.randint(0, self.config['num_classes']))
keys = self.clsidx2ytid[cls_id]
k = random_choice(self.all_keys)
cls_ids = self.ytid2clsidx[k]
else:
idx = idx % len(self.all_keys)
k = self.all_keys[idx]
cls_ids = self.ytid2clsidx[k]
fp = self.key2file[k]
x = fp[k][:, :]
x = self._process(x)
y = np.zeros((self.config['num_classes'], ), 'float32')
y[cls_ids] = 1.0
return x, y
def __len__(self):
return len(self.all_keys)
def get_ytid2labels(segment_csv):
"""
compute the mapping (dict object) from youtube id to audioset labels.
"""
with open(segment_csv) as F:
lines = F.read().split('\n')
lines = [l for l in lines if len(l) > 0 and l[0] != '#']
ytid2labels = {l.split(',')[0]: l.split('"')[-2] for l in lines}
return ytid2labels
def worker_init(worker_id):
time.sleep(worker_id / 32)
np.random.seed(int(time.time()) % 100 + worker_id)
def get_train_loader(config):
train_h5_files = glob.glob(config['unbalanced_train_h5'])
train_h5_files += [config['balanced_train_h5']]
train_dataset = H5AudioSet(train_h5_files,
config,
balanced_sampling=config['balanced_sampling'],
augment=True,
training=True)
train_loader = DataLoader(train_dataset,
shuffle=True,
batch_size=config['batch_size'],
drop_last=True,
num_workers=config['num_workers'],
use_buffer_reader=True,
use_shared_memory=True,
worker_init_fn=worker_init)
return train_loader
def get_val_loader(config):
val_dataset = H5AudioSet([config['balanced_eval_h5']],
config,
balanced_sampling=False,
augment=False)
val_loader = DataLoader(val_dataset,
shuffle=False,
batch_size=config['val_batch_size'],
drop_last=False,
num_workers=config['num_workers'])
return val_loader
if __name__ == '__main__':
# do some testing here
with open('./assets/config.yaml') as f:
config = yaml.safe_load(f)
train_h5_files = glob.glob(config['unbalanced_train_h5'])
dataset = H5AudioSet(train_h5_files,
config,
balanced_sampling=True,
augment=True,
training=True)
x, y = dataset[1]
print(x.shape, y.shape)
dataset = H5AudioSet([config['balanced_eval_h5']],
config,
balanced_sampling=False,
augment=False)
x, y = dataset[0]
print(x.shape, y.shape)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import numpy as np
import paddle
import paddle.nn.functional as F
import paddleaudio as pa
import yaml
from dataset import get_val_loader
from model import resnet50
from paddle.utils import download
from sklearn.metrics import average_precision_score, roc_auc_score
from utils import compute_dprime,download_assets
checkpoint_url = 'https://bj.bcebos.com/paddleaudio/paddleaudio/resnet50_weight_averaging_mAP0.416.pdparams'
def evaluate(epoch, val_loader, model, loss_fn):
model.eval()
avg_loss = 0.0
all_labels = []
all_preds = []
for batch_id, (x, y) in enumerate(val_loader()):
x = x.unsqueeze((1))
label = y
logits = model(x)
loss_val = loss_fn(logits, label)
pred = F.sigmoid(logits)
all_labels += [label.numpy()]
all_preds += [pred.numpy()]
avg_loss = (avg_loss * batch_id + loss_val.numpy()[0]) / (1 + batch_id)
msg = f'eval epoch:{epoch}, batch:{batch_id}'
msg += f'|{len(val_loader)}'
msg += f',loss:{avg_loss:.3}'
if batch_id % 20 == 0:
print(msg)
all_preds = np.concatenate(all_preds, 0)
all_labels = np.concatenate(all_labels, 0)
mAP_score = np.mean(
average_precision_score(all_labels, all_preds, average=None))
auc_score = np.mean(roc_auc_score(all_labels, all_preds, average=None))
dprime = compute_dprime(auc_score)
return avg_loss, mAP_score, auc_score, dprime
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Audioset inference')
parser.add_argument('--config',
type=str,
required=False,
default='./assets/config.yaml')
parser.add_argument('--device',
help='set the gpu device number',
type=int,
required=False,
default=0)
parser.add_argument('--weight', type=str, required=False, default='')
args = parser.parse_args()
download_assets()
with open(args.config) as f:
c = yaml.safe_load(f)
paddle.set_device('gpu:{}'.format(args.device))
ModelClass = eval(c['model_type'])
model = ModelClass(pretrained=False,
num_classes=c['num_classes'],
dropout=c['dropout'])
if args.weight.strip() == '':
print(f'Using pretrained weight: {checkpoint_url}')
args.weight = download.get_weights_path_from_url(checkpoint_url)
model.load_dict(paddle.load(args.weight))
model.eval()
val_loader = get_val_loader(c)
print(f'Evaluating...')
avg_loss, mAP_score, auc_score, dprime = evaluate(
0, val_loader, model, F.binary_cross_entropy_with_logits)
print(f'mAP: {mAP_score:.3}')
print(f'auc: {auc_score:.3}')
print(f'd-prime: {dprime:.3}')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import numpy as np
import paddle
import paddle.nn.functional as F
import paddleaudio as pa
import yaml
from model import resnet50
from paddle.utils import download
from utils import (download_assets, get_label_name_mapping, get_labels,
get_metrics)
download_assets()
checkpoint_url = 'https://bj.bcebos.com/paddleaudio/paddleaudio/resnet50_weight_averaging_mAP0.416.pdparams'
def load_and_extract_feature(file, c):
s, _ = pa.load(file, sr=c['sample_rate'])
x = pa.features.melspectrogram(s,
sr=c['sample_rate'],
window_size=c['window_size'],
hop_length=c['hop_size'],
n_mels=c['mel_bins'],
fmin=c['fmin'],
fmax=c['fmax'],
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
x = x.T # !!
x = paddle.Tensor(x).unsqueeze((0, 1))
return x
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Audioset inference')
parser.add_argument('--device',
help='set the gpu device number',
type=int,
required=False,
default=0)
parser.add_argument('--config',
type=str,
required=False,
default='./assets/config.yaml')
parser.add_argument('--weight', type=str, required=False, default='')
parser.add_argument('--wav_file',
type=str,
required=False,
default='./wav/TKtNAJa-mbQ_11.000.wav')
parser.add_argument('--top_k', type=int, required=False, default=20)
args = parser.parse_args()
top_k = args.top_k
label2name, name2label = get_label_name_mapping()
with open(args.config) as f:
c = yaml.safe_load(f)
paddle.set_device('gpu:{}'.format(args.device))
ModelClass = eval(c['model_type'])
model = ModelClass(pretrained=False,
num_classes=c['num_classes'],
dropout=c['dropout'])
if args.weight.strip() == '':
args.weight = download.get_weights_path_from_url(checkpoint_url)
model.load_dict(paddle.load(args.weight))
model.eval()
x = load_and_extract_feature(args.wav_file, c)
labels = get_labels()
logits = model(x)
pred = F.sigmoid(logits)
pred = pred[0].cpu().numpy()
clsidx = np.argsort(pred)[-top_k:][::-1]
probs = np.sort(pred)[-top_k:][::-1]
for i, idx in enumerate(clsidx):
name = label2name[labels[idx]]
print(name, probs[i])
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division, print_function
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
__all__ = [
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
]
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
'cf548f46534aa3560945be4b95cd11c4'),
'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',
'8d2275cf8706028345f78ac0e1d31969'),
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
'ca6f485ee1ab0492d38f323885b0ad80'),
'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',
'02f35f034ca3858e1e54d4036443c92d'),
'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',
'7ad16a2f1e7333859ff986138630fd7a'),
}
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
self.conv1 = nn.Conv2D(inplanes,
planes,
3,
padding=1,
stride=stride,
bias_attr=False)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class BottleneckBlock(nn.Layer):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BottleneckBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
width = int(planes * (base_width / 64.)) * groups
self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
self.bn1 = norm_layer(width)
self.conv2 = nn.Conv2D(width,
width,
3,
padding=dilation,
stride=stride,
groups=groups,
dilation=dilation,
bias_attr=False)
self.bn2 = norm_layer(width)
self.conv3 = nn.Conv2D(width,
planes * self.expansion,
1,
bias_attr=False)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Layer):
"""ResNet model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): layers of resnet, default: 50.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
resnet50 = ResNet(BottleneckBlock, 50)
resnet18 = ResNet(BasicBlock, 18)
"""
def __init__(self,
block,
depth,
num_classes=1000,
with_pool=True,
dropout=0.5):
super(ResNet, self).__init__()
layer_cfg = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3]
}
layers = layer_cfg[depth]
self.num_classes = num_classes
self.with_pool = with_pool
self._norm_layer = nn.BatchNorm2D
self.inplanes = 64
self.dilation = 1
self.bn0 = nn.BatchNorm2D(128)
self.conv1 = nn.Conv2D(1,
self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias_attr=False)
self.bn1 = self._norm_layer(self.inplanes)
self.relu = nn.ReLU()
self.relu2 = nn.ReLU()
self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.drop1 = nn.Dropout2D(dropout)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.drop2 = nn.Dropout2D(dropout)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.drop3 = nn.Dropout2D(dropout)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.drop4 = nn.Dropout2D(dropout)
self.drop = nn.Dropout(dropout)
self.extra_fc = nn.Linear(512 * block.expansion, 1024 * 2)
if with_pool:
self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
if num_classes > 0:
self.fc = nn.Linear(1024 * 2, num_classes)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(self.inplanes,
planes * block.expansion,
1,
stride=stride,
bias_attr=False),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, 1, 64,
previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.drop1(x)
x = self.layer2(x)
x = self.drop2(x)
x = self.layer3(x)
x = self.drop3(x)
x = self.layer4(x)
x = self.drop4(x)
if self.with_pool:
x = self.avgpool(x)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.drop(x)
x = self.extra_fc(x)
x = self.relu2(x)
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained, **kwargs):
model = ResNet(Block, depth, **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
param = paddle.load(weight_path)
model.set_dict(param)
return model
def resnet18(pretrained=False, **kwargs):
"""ResNet 18-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet18
# build model
model = resnet18()
# build model and load imagenet pretrained weight
# model = resnet18(pretrained=True)
"""
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
def resnet34(pretrained=False, **kwargs):
"""ResNet 34-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet34
# build model
model = resnet34()
# build model and load imagenet pretrained weight
# model = resnet34(pretrained=True)
"""
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
def resnet50(pretrained=False, **kwargs):
"""ResNet 50-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet50
# build model
model = resnet50()
# build model and load imagenet pretrained weight
# model = resnet50(pretrained=True)
"""
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
def resnet101(pretrained=False, **kwargs):
"""ResNet 101-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet101
# build model
model = resnet101()
# build model and load imagenet pretrained weight
# model = resnet101(pretrained=True)
"""
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
def resnet152(pretrained=False, **kwargs):
"""ResNet 152-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet152
# build model
model = resnet152()
# build model and load imagenet pretrained weight
# model = resnet152(pretrained=True)
"""
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import os
import time
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
import yaml
from dataset import get_train_loader, get_val_loader
from evaluate import evaluate
from model import resnet18, resnet50, resnet101
from paddle.io import DataLoader, Dataset, IterableDataset
from paddle.optimizer import Adam
from utils import (MixUpLoss, get_metrics, load_checkpoint, mixup_data,
save_checkpoint)
from visualdl import LogWriter
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Audioset training')
parser.add_argument('--device', type=int, required=False, default=1)
parser.add_argument('--restore', type=int, required=False, default=-1)
parser.add_argument('--config',
type=str,
required=False,
default='./assets/config.yaml')
parser.add_argument('--distributed', type=int, required=False, default=0)
args = parser.parse_args()
with open(args.config) as f:
c = yaml.safe_load(f)
log_writer = LogWriter(logdir=c['log_path'])
prefix = 'mixup_{}'.format(c['model_type'])
if args.distributed != 0:
dist.init_parallel_env()
local_rank = dist.get_rank()
else:
paddle.set_device('gpu:{}'.format(args.device))
local_rank = 0
print(f'using ' + c['model_type'])
ModelClass = eval(c['model_type'])
#define loss
bce_loss = F.binary_cross_entropy_with_logits
loss_fn = MixUpLoss(bce_loss)
warm_steps = c['warm_steps']
lrs = np.linspace(1e-10, c['start_lr'], warm_steps)
# restore checkpoint
if args.restore != -1:
model = ModelClass(pretrained=False,
num_classes=c['num_classes'],
dropout=c['dropout'])
model_dict, optim_dict = load_checkpoint(c['model_dir'], args.restore,
prefix)
model.load_dict(model_dict)
optimizer = Adam(learning_rate=c['start_lr'],
parameters=model.parameters())
optimizer.set_state_dict(optim_dict)
start_epoch = args.restore
else:
model = ModelClass(pretrained=True,
num_classes=c['num_classes'],
dropout=c['dropout']) # use imagenet pretrained
optimizer = Adam(learning_rate=c['start_lr'],
parameters=model.parameters())
start_epoch = 0
#for name,p in list(model.named_parameters())[:-2]:
# print(name,p.stop_gradient)
# p.stop_gradient = True
os.makedirs(c['model_dir'], exist_ok=True)
if args.distributed != 0:
model = paddle.DataParallel(model)
train_loader = get_train_loader(c)
val_loader = get_val_loader(c)
epoch_num = c['epoch_num']
if args.restore != -1:
avg_loss, mAP_score, auc_score, dprime = evaluate(
args.restore, val_loader, model, bce_loss)
print(f'average map at epoch {args.restore} is {mAP_score}')
print(f'auc_score: {auc_score}')
print(f'd-prime: {dprime}')
best_mAP = mAP_score
log_writer.add_scalar(tag="eval mAP",
step=args.restore,
value=mAP_score)
log_writer.add_scalar(tag="eval auc",
step=args.restore,
value=auc_score)
log_writer.add_scalar(tag="eval dprime",
step=args.restore,
value=dprime)
else:
best_mAP = 0.0
step = 0
for epoch in range(start_epoch, epoch_num):
avg_loss = 0.0
avg_preci = 0.0
avg_recall = 0.0
model.train()
model.clear_gradients()
t0 = time.time()
for batch_id, (x,y) in enumerate(train_loader()):
if step < warm_steps:
optimizer.set_lr(lrs[step])
x.stop_gradient = False
if c['balanced_sampling']:
x = x.squeeze()
y = y.squeeze()
x = x.unsqueeze((1))
if c['mixup']:
mixed_x, mixed_y = mixup_data(x, y, c['mixup_alpha'])
logits = model(mixed_x)
loss_val = loss_fn(logits, mixed_y)
loss_val.backward()
else:
logits = model(x)
loss_val = bce_loss(logits, y)
loss_val.backward()
optimizer.step()
model.clear_gradients()
pred = F.sigmoid(logits)
preci, recall = get_metrics(y.squeeze().numpy(), pred.numpy())
avg_loss = (avg_loss * batch_id + loss_val.numpy()[0]) / (1 +
batch_id)
avg_preci = (avg_preci * batch_id + preci) / (1 + batch_id)
avg_recall = (avg_recall * batch_id + recall) / (1 + batch_id)
elapsed = (time.time() - t0) / 3600
remain = elapsed / (1 + batch_id) * (len(train_loader) - batch_id)
msg = f'epoch:{epoch}, batch:{batch_id}'
msg += f'|{len(train_loader)}'
msg += f',loss:{avg_loss:.3}'
msg += f',recall:{avg_recall:.3}'
msg += f',preci:{avg_preci:.3}'
msg += f',elapsed:{elapsed:.1}h'
msg += f',remained:{remain:.1}h'
if batch_id % 20 == 0 and local_rank == 0:
print(msg)
log_writer.add_scalar(tag="train loss",
step=step,
value=avg_loss)
log_writer.add_scalar(tag="train preci",
step=step,
value=avg_preci)
log_writer.add_scalar(tag="train recall",
step=step,
value=avg_recall)
step += 1
if step % c['checkpoint_step'] == 0 and local_rank == 0:
save_checkpoint(c['model_dir'], epoch, model, optimizer, prefix)
avg_loss, avg_map, auc_score, dprime = evaluate(
epoch, val_loader, model, bce_loss)
print(f'average map at epoch {epoch} is {avg_map}')
print(f'auc: {auc_score}')
print(f'd-prime: {dprime}')
log_writer.add_scalar(tag="eval mAP", step=epoch, value=avg_map)
log_writer.add_scalar(tag="eval auc",
step=epoch,
value=auc_score)
log_writer.add_scalar(tag="eval dprime",
step=epoch,
value=dprime)
model.train()
model.clear_gradients()
if avg_map > best_mAP:
print('mAP improved from {} to {}'.format(
best_mAP, avg_map))
best_mAP = avg_map
fn = os.path.join(
c['model_dir'],
f'{prefix}_epoch{epoch}_mAP{avg_map:.3}.pdparams')
paddle.save(model.state_dict(), fn)
else:
print(f'mAP {avg_map} did not improved from {best_mAP}')
if step % c['lr_dec_per_step'] == 0 and step != 0:
if optimizer.get_lr() <= 1e-6:
factor = 0.95
else:
factor = 0.8
optimizer.set_lr(optimizer.get_lr() * factor)
print('decreased lr to {}'.format(optimizer.get_lr()))
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import numpy as np
import paddle
import paddle.nn.functional as F
from scipy import stats
from sklearn.metrics import average_precision_score
__all__ = [
'save_checkpoint', 'load_checkpoint', 'get_labels', 'random_choice',
'get_label_name_mapping'
]
def random_choice(a):
i = np.random.randint(0, high=len(a), size=(1, ))
return a[int(i)]
def get_labels():
with open('./assets/audioset_labels.txt') as F:
labels = F.read().split('\n')
return labels
def get_ytid_clsidx_mapping():
"""
Compute the mapping between youtube id and class index.
The class index range from 0 to 527, correspoding to the labels stored in audioset_labels.txt file
"""
labels = get_labels()
label2clsidx = {l: i for i, l in enumerate(labels)}
lines = open('./assets/unbalanced_train_segments.csv').read().split('\n')
lines += open('./assets/balanced_train_segments.csv').read().split('\n')
lines += open('./assets/eval_segments.csv').read().split('\n')
lines = [l for l in lines if len(l) > 0 and l[0] != '#']
ytid2clsidx = {}
for l in lines:
ytid = l.split(',')[0]
labels = l.split(',')[3:]
cls_idx = []
for label in labels:
label = label.replace('"', '').strip()
cls_idx.append(label2clsidx[label])
ytid2clsidx.update({ytid: cls_idx})
clsidx2ytid = {i: [] for i in range(527)}
for k in ytid2clsidx.keys():
for v in ytid2clsidx[k]:
clsidx2ytid[v] += [k]
return ytid2clsidx, clsidx2ytid
def get_metrics(label, pred):
a = label
b = (pred > 0.5).astype('int32')
eps = 1e-8
tp = np.sum(b[a == 1])
fp = np.sum(b[a == 0])
precision = tp / (fp + tp + eps)
fn = np.sum(b[a == 1] == 0)
recall = tp / (tp + fn)
return precision, recall
def compute_dprime(auc):
"""Compute d_prime metric.
Reference:
J. F. Gemmeke, D. P. Ellis, D. Freedman, A. Jansen, W. Lawrence, R. C. Moore, M. Plakal, and M. Ritter, “Audio Set: An ontology and humanlabeled dataset for audio events,” in IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2017, pp. 776–780.
"""
dp = stats.norm().ppf(auc) * np.sqrt(2.0)
return dp
def save_checkpoint(model_dir, step, model, optimizer, prefix):
print(f'checkpointing at step {step}')
paddle.save(model.state_dict(),
model_dir + '/{}_checkpoint{}.pdparams'.format(prefix, step))
paddle.save(optimizer.state_dict(),
model_dir + '/{}_checkpoint{}.pdopt'.format(prefix, step))
def load_checkpoint(model_dir, epoch, prefix):
file = model_dir + '/{}_checkpoint_model{}.tar'.format(prefix, epoch)
print('loading checkpoing ' + file)
model_dict = paddle.load(model_dir +
'/{}_checkpoint{}.pdparams'.format(prefix, epoch))
optim_dict = paddle.load(model_dir +
'/{}_checkpoint{}.pdopt'.format(prefix, epoch))
return model_dict, optim_dict
def get_label_name_mapping():
with open('./assets/ontology.json') as F:
ontology = json.load(F)
label2name = {o['id']: o['name'] for o in ontology}
name2label = {o['name']: o['id'] for o in ontology}
return label2name, name2label
def download_assets():
os.makedirs('./assets/', exist_ok=True)
urls = [
'https://raw.githubusercontent.com/audioset/ontology/master/ontology.json',
'http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/eval_segments.csv',
'http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv',
'http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/unbalanced_train_segments.csv'
]
for url in urls:
fname = './assets/' + url.split('/')[-1]
if os.path.exists(fname):
continue
cmd = 'wget ' + url + ' -O ' + fname
print(cmd)
os.system(cmd)
class MixUpLoss(paddle.nn.Layer):
"""Define the mixup loss used in training audioset.
Reference:
Zhang, Hongyi, et al. “Mixup: Beyond Empirical Risk Minimization.” International Conference on Learning Representations, 2017.
"""
def __init__(self, criterion):
super(MixUpLoss, self).__init__()
self.criterion = criterion
def forward(self, pred, mixup_target):
assert type(mixup_target) in [
tuple, list
] and len(mixup_target
) == 3, 'mixup data should be tuple consists of (ya,yb,lamda)'
ya, yb, lamda = mixup_target
return lamda * self.criterion(pred, ya) \
+ (1 - lamda) * self.criterion(pred, yb)
def extra_repr(self):
return 'MixUpLoss with {}'.format(self.criterion)
def mixup_data(x, y, alpha=1.0):
"""Mix the input data and label using mixup strategy, returns mixed inputs,
pairs of targets, and lambda
Reference:
Zhang, Hongyi, et al. “Mixup: Beyond Empirical Risk Minimization.” International Conference on Learning Representations, 2017.
"""
if alpha > 0:
lam = np.random.beta(alpha, alpha)
else:
lam = 1
batch_size = x.shape[0]
index = paddle.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * paddle.index_select(x, index)
y_a, y_b = y, paddle.index_select(y, index)
mixed_target = (y_a, y_b, lam)
return mixed_x, mixed_target
......@@ -4,9 +4,8 @@ import os
import h5py
import numpy as np
import tqdm
import paddleaudio as pa
import tqdm
parser = argparse.ArgumentParser(description='wave2mel')
parser.add_argument('--wav_file', type=str, required=False, default='')
......@@ -20,10 +19,12 @@ parser.add_argument('--dst_h5_file', type=str, required=False, default='')
parser.add_argument('--sample_rate', type=int, required=False, default=32000)
parser.add_argument('--window_size', type=int, required=False, default=1024)
parser.add_argument('--mel_bins', type=int, required=False, default=128)
parser.add_argument('--hop_length', type=int, required=False, default=640) #20ms
parser.add_argument('--hop_length', type=int, required=False,
default=640) #20ms
parser.add_argument('--fmin', type=int, required=False, default=50) #25ms
parser.add_argument('--fmax', type=int, required=False, default=16000) #25ms
parser.add_argument('--skip_existed', type=int, required=False, default=1) #25ms
parser.add_argument('--skip_existed', type=int, required=False,
default=1) #25ms
args = parser.parse_args()
......@@ -63,19 +64,19 @@ if len(h5_files) > 0:
s = src_h5[key][:]
s = pa.depth_convert(s, 'float32')
# s = pa.resample(s,32000,args.sample_rate)
x = pa.features.mel_spect(s,
sample_rate=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
x = pa.features.melspectrogram(s,
sr=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
n_mels=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
dst_h5.create_dataset(key, data=x)
src_h5.close()
dst_h5.close()
......@@ -90,20 +91,19 @@ if len(wav_files) > 0:
print(f'{len(wav_files)} wav files listed')
for f in tqdm.tqdm(wav_files):
s, _ = pa.load(f, sr=args.sample_rate)
# s = pa.resample(s,32000,args.sample_rate)
x = pa.features.mel_spect(s,
sample_rate=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
x = pa.melspectrogram(s,
sr=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
n_mels=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
# figure(figsize=(8,8))
# imshow(x)
# show()
......
Speech
Male speech, man speaking
Female speech, woman speaking
Child speech, kid speaking
Conversation
Narration, monologue
Babbling
Speech synthesizer
Shout
Bellow
Whoop
Yell
Battle cry
Children shouting
Screaming
Whispering
Laughter
Baby laughter
Giggle
Snicker
Belly laugh
Chuckle, chortle
Crying, sobbing
Baby cry, infant cry
Whimper
Wail, moan
Sigh
Singing
Choir
Yodeling
Chant
Mantra
Male singing
Female singing
Child singing
Synthetic singing
Rapping
Humming
Groan
Grunt
Whistling
Breathing
Wheeze
Snoring
Gasp
Pant
Snort
Cough
Throat clearing
Sneeze
Sniff
Run
Shuffle
Walk, footsteps
Chewing, mastication
Biting
Gargling
Stomach rumble
Burping, eructation
Hiccup
Fart
Hands
Finger snapping
Clapping
Heart sounds, heartbeat
Heart murmur
Cheering
Applause
Chatter
Crowd
Hubbub, speech noise, speech babble
Children playing
Animal
Domestic animals, pets
Dog
Bark
Yip
Howl
Bow-wow
Growling
Whimper (dog)
Cat
Purr
Meow
Hiss
Caterwaul
Livestock, farm animals, working animals
Horse
Clip-clop
Neigh, whinny
Cattle, bovinae
Moo
Cowbell
Pig
Oink
Goat
Bleat
Sheep
Fowl
Chicken, rooster
Cluck
Crowing, cock-a-doodle-doo
Turkey
Gobble
Duck
Quack
Goose
Honk
Wild animals
Roaring cats (lions, tigers)
Roar
Bird
Bird vocalization, bird call, bird song
Chirp, tweet
Squawk
Pigeon, dove
Coo
Crow
Caw
Owl
Hoot
Bird flight, flapping wings
Canidae, dogs, wolves
Rodents, rats, mice
Mouse
Patter
Insect
Cricket
Mosquito
Fly, housefly
Buzz
Bee, wasp, etc.
Frog
Croak
Snake
Rattle
Whale vocalization
Music
Musical instrument
Plucked string instrument
Guitar
Electric guitar
Bass guitar
Acoustic guitar
Steel guitar, slide guitar
Tapping (guitar technique)
Strum
Banjo
Sitar
Mandolin
Zither
Ukulele
Keyboard (musical)
Piano
Electric piano
Organ
Electronic organ
Hammond organ
Synthesizer
Sampler
Harpsichord
Percussion
Drum kit
Drum machine
Drum
Snare drum
Rimshot
Drum roll
Bass drum
Timpani
Tabla
Cymbal
Hi-hat
Wood block
Tambourine
Rattle (instrument)
Maraca
Gong
Tubular bells
Mallet percussion
Marimba, xylophone
Glockenspiel
Vibraphone
Steelpan
Orchestra
Brass instrument
French horn
Trumpet
Trombone
Bowed string instrument
String section
Violin, fiddle
Pizzicato
Cello
Double bass
Wind instrument, woodwind instrument
Flute
Saxophone
Clarinet
Harp
Bell
Church bell
Jingle bell
Bicycle bell
Tuning fork
Chime
Wind chime
Change ringing (campanology)
Harmonica
Accordion
Bagpipes
Didgeridoo
Shofar
Theremin
Singing bowl
Scratching (performance technique)
Pop music
Hip hop music
Beatboxing
Rock music
Heavy metal
Punk rock
Grunge
Progressive rock
Rock and roll
Psychedelic rock
Rhythm and blues
Soul music
Reggae
Country
Swing music
Bluegrass
Funk
Folk music
Middle Eastern music
Jazz
Disco
Classical music
Opera
Electronic music
House music
Techno
Dubstep
Drum and bass
Electronica
Electronic dance music
Ambient music
Trance music
Music of Latin America
Salsa music
Flamenco
Blues
Music for children
New-age music
Vocal music
A capella
Music of Africa
Afrobeat
Christian music
Gospel music
Music of Asia
Carnatic music
Music of Bollywood
Ska
Traditional music
Independent music
Song
Background music
Theme music
Jingle (music)
Soundtrack music
Lullaby
Video game music
Christmas music
Dance music
Wedding music
Happy music
Funny music
Sad music
Tender music
Exciting music
Angry music
Scary music
Wind
Rustling leaves
Wind noise (microphone)
Thunderstorm
Thunder
Water
Rain
Raindrop
Rain on surface
Stream
Waterfall
Ocean
Waves, surf
Steam
Gurgling
Fire
Crackle
Vehicle
Boat, Water vehicle
Sailboat, sailing ship
Rowboat, canoe, kayak
Motorboat, speedboat
Ship
Motor vehicle (road)
Car
Vehicle horn, car horn, honking
Toot
Car alarm
Power windows, electric windows
Skidding
Tire squeal
Car passing by
Race car, auto racing
Truck
Air brake
Air horn, truck horn
Reversing beeps
Ice cream truck, ice cream van
Bus
Emergency vehicle
Police car (siren)
Ambulance (siren)
Fire engine, fire truck (siren)
Motorcycle
Traffic noise, roadway noise
Rail transport
Train
Train whistle
Train horn
Railroad car, train wagon
Train wheels squealing
Subway, metro, underground
Aircraft
Aircraft engine
Jet engine
Propeller, airscrew
Helicopter
Fixed-wing aircraft, airplane
Bicycle
Skateboard
Engine
Light engine (high frequency)
Dental drill, dentist's drill
Lawn mower
Chainsaw
Medium engine (mid frequency)
Heavy engine (low frequency)
Engine knocking
Engine starting
Idling
Accelerating, revving, vroom
Door
Doorbell
Ding-dong
Sliding door
Slam
Knock
Tap
Squeak
Cupboard open or close
Drawer open or close
Dishes, pots, and pans
Cutlery, silverware
Chopping (food)
Frying (food)
Microwave oven
Blender
Water tap, faucet
Sink (filling or washing)
Bathtub (filling or washing)
Hair dryer
Toilet flush
Toothbrush
Electric toothbrush
Vacuum cleaner
Zipper (clothing)
Keys jangling
Coin (dropping)
Scissors
Electric shaver, electric razor
Shuffling cards
Typing
Typewriter
Computer keyboard
Writing
Alarm
Telephone
Telephone bell ringing
Ringtone
Telephone dialing, DTMF
Dial tone
Busy signal
Alarm clock
Siren
Civil defense siren
Buzzer
Smoke detector, smoke alarm
Fire alarm
Foghorn
Whistle
Steam whistle
Mechanisms
Ratchet, pawl
Clock
Tick
Tick-tock
Gears
Pulleys
Sewing machine
Mechanical fan
Air conditioning
Cash register
Printer
Camera
Single-lens reflex camera
Tools
Hammer
Jackhammer
Sawing
Filing (rasp)
Sanding
Power tool
Drill
Explosion
Gunshot, gunfire
Machine gun
Fusillade
Artillery fire
Cap gun
Fireworks
Firecracker
Burst, pop
Eruption
Boom
Wood
Chop
Splinter
Crack
Glass
Chink, clink
Shatter
Liquid
Splash, splatter
Slosh
Squish
Drip
Pour
Trickle, dribble
Gush
Fill (with liquid)
Spray
Pump (liquid)
Stir
Boiling
Sonar
Arrow
Whoosh, swoosh, swish
Thump, thud
Thunk
Electronic tuner
Effects unit
Chorus effect
Basketball bounce
Bang
Slap, smack
Whack, thwack
Smash, crash
Breaking
Bouncing
Whip
Flap
Scratch
Scrape
Rub
Roll
Crushing
Crumpling, crinkling
Tearing
Beep, bleep
Ping
Ding
Clang
Squeal
Creak
Rustle
Whir
Clatter
Sizzle
Clicking
Clickety-clack
Rumble
Plop
Jingle, tinkle
Hum
Zing
Boing
Crunch
Silence
Sine wave
Harmonic
Chirp tone
Sound effect
Pulse
Inside, small room
Inside, large room or hall
Inside, public space
Outside, urban or manmade
Outside, rural or natural
Reverberation
Echo
Noise
Environmental noise
Static
Mains hum
Distortion
Sidetone
Cacophony
White noise
Pink noise
Throbbing
Vibration
Television
Radio
Field recording
......@@ -52,13 +52,15 @@ def split(waveform: np.ndarray, win_size: int, hop_size: int):
return time, data
def batchify(data: List[List[float]], sample_rate: int, batch_size: int, **kwargs):
def batchify(data: List[List[float]], sample_rate: int, batch_size: int,
**kwargs):
"""
Extract features from waveforms and create batches.
"""
examples = []
for waveform in data:
feats = mel_spect(waveform, sample_rate=sample_rate, **kwargs).transpose()
feats = mel_spect(waveform, sample_rate=sample_rate,
**kwargs).transpose()
examples.append(feats)
# Seperates data into some batches.
......@@ -72,7 +74,10 @@ def batchify(data: List[List[float]], sample_rate: int, batch_size: int, **kwarg
yield one_batch
def predict(model, data: List[List[float]], sample_rate: int, batch_size: int = 1):
def predict(model,
data: List[List[float]],
sample_rate: int,
batch_size: int = 1):
"""
Use pretrained model to make predictions.
"""
......@@ -96,7 +101,8 @@ if __name__ == '__main__':
paddle.set_device(args.device)
model = cnn14(pretrained=True, extract_embedding=False)
waveform, sr = load_audio(args.wav, sr=None)
time, data = split(waveform, int(args.sample_duration * sr), int(args.hop_duration * sr))
time, data = split(waveform, int(args.sample_duration * sr),
int(args.hop_duration * sr))
results = predict(model, data, sr, batch_size=8)
if not os.path.exists(args.output_dir):
......
......@@ -73,7 +73,9 @@ if __name__ == "__main__":
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
output_file = os.path.join(args.output_dir, os.path.basename(args.tagging_file).split('.')[0] + '.txt')
output_file = os.path.join(
args.output_dir,
os.path.basename(args.tagging_file).split('.')[0] + '.txt')
with open(output_file, 'w') as f:
for time, score in zip(times, scores):
f.write(f'{time}\n')
......
import argparse
import glob
import os
import h5py
import numpy as np
import tqdm
import paddleaudio as pa
#from pylab import *
parser = argparse.ArgumentParser(description='wave2mel')
parser.add_argument('--wav_file', type=str, required=False, default='')
parser.add_argument('--wav_list', type=str, required=False, default='')
parser.add_argument('--wav_h5_file', type=str, required=False, default='')
parser.add_argument('--wav_h5_list', type=str, required=False, default='')
parser.add_argument('--output_folder', type=str, required=False, default='./')
parser.add_argument('--output_h5', type=bool, required=False, default=True)
parser.add_argument('--sample_rate', type=int, required=False, default=32000)
parser.add_argument('--window_size', type=int, required=False, default=1024)
parser.add_argument('--mel_bins', type=int, required=False, default=128)
parser.add_argument('--hop_length', type=int, required=False, default=640) #20ms
parser.add_argument('--fmin', type=int, required=False, default=50) #25ms
parser.add_argument('--fmax', type=int, required=False, default=16000) #25ms
args = parser.parse_args()
#args.wav_h5_file = '/ssd2/laiyongquan/audioset/h5/audioset_unblance_group28.h5'
assert not (args.wav_h5_file == '' and args.wav_h5_list == ''\
and args.wav_list == '' and args.wav_file == ''), 'one of wav_file,wav_list,\
wav_h5_file,wav_h5_list needs to specify'
if args.wav_h5_file != '':
h5_files = [args.wav_h5_file]
if args.wav_h5_list != '':
h5_files = open(args.wav_h5_list).read().split('\n')
h5_files = [h for h in h5_files if len(h.strip()) != 0]
dst_folder = args.output_folder
print(f'{len(h5_files)} h5 files listed')
for f in h5_files:
print(f'processing {f}')
dst_file = os.path.join(dst_folder, f.split('/')[-1])
print(f'target file {dst_file}')
assert not os.path.exists(dst_file), f'target file {dst_file} existed'
src_h5 = h5py.File(f)
dst_h5 = h5py.File(dst_file, "w")
for key in tqdm.tqdm(src_h5.keys()):
s = src_h5[key][:]
s = pa.depth_convert(s, 'float32')
# s = pa.resample(s,32000,args.sample_rate)
x = pa.features.mel_spect(s,
sample_rate=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
# figure(figsize=(8,8))
# imshow(x)
# show()
# print(x.shape)
dst_h5.create_dataset(key, data=x)
src_h5.close()
dst_h5.close()
import argparse
import glob
import os
import h5py
import numpy as np
import tqdm
import paddleaudio as pa
#from pylab import *
parser = argparse.ArgumentParser(description='wave2mel')
parser.add_argument('--wav_file', type=str, required=False, default='')
parser.add_argument('--wav_list', type=str, required=False, default='')
parser.add_argument('--wav_h5_file', type=str, required=False, default='')
parser.add_argument('--wav_h5_list', type=str, required=False, default='')
parser.add_argument('--output_folder', type=str, required=False, default='./')
parser.add_argument('--output_h5', type=bool, required=False, default=True)
parser.add_argument('--sample_rate', type=int, required=False, default=32000)
parser.add_argument('--window_size', type=int, required=False, default=1024)
parser.add_argument('--mel_bins', type=int, required=False, default=128)
parser.add_argument('--hop_length', type=int, required=False, default=640) #20ms
parser.add_argument('--fmin', type=int, required=False, default=50) #25ms
parser.add_argument('--fmax', type=int, required=False, default=16000) #25ms
parser.add_argument('--skip_existed', type=int, required=False, default=1) #25ms
args = parser.parse_args()
#args.wav_h5_file = '/ssd2/laiyongquan/audioset/h5/audioset_unblance_group28.h5'
assert not (args.wav_h5_file == '' and args.wav_h5_list == ''\
and args.wav_list == '' and args.wav_file == ''), 'one of wav_file,wav_list,\
wav_h5_file,wav_h5_list needs to specify'
if args.wav_h5_file != '':
h5_files = [args.wav_h5_file]
if args.wav_h5_list != '':
h5_files = open(args.wav_h5_list).read().split('\n')
h5_files = [h for h in h5_files if len(h.strip()) != 0]
dst_folder = args.output_folder
print(f'{len(h5_files)} h5 files listed')
for f in h5_files:
print(f'processing {f}')
dst_file = os.path.join(dst_folder, f.split('/')[-1])
print(f'target file {dst_file}')
if args.skip_existed != 0 and os.path.exists(dst_file):
print(f'skipped file {f}')
continue
assert not os.path.exists(dst_file), f'target file {dst_file} existed'
src_h5 = h5py.File(f)
dst_h5 = h5py.File(dst_file, "w")
for key in tqdm.tqdm(src_h5.keys()):
s = src_h5[key][:]
s = pa.depth_convert(s, 'float32')
# s = pa.resample(s,32000,args.sample_rate)
x = pa.features.mel_spect(s,
sample_rate=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
# figure(figsize=(8,8))
# imshow(x)
# show()
# print(x.shape)
dst_h5.create_dataset(key, data=x)
src_h5.close()
dst_h5.close()
from .backends import *
from .features import *
......@@ -11,131 +11,156 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional, Tuple, Type, Union
import numpy as np
import resampy
import soundfile as sf
from numpy import ndarray as array
from scipy.io import wavfile
try:
import librosa
has_librosa = True
except:
has_librosa = False
try:
import soundfile as sf
has_snf = True
except:
has_snf = False
try:
import resampy
has_resampy = True
except:
has_resampy = False
from ..utils import ParameterError
__norm_types__ = ['linear', 'gaussian']
__mono_types__ = ['ch0', 'ch1', 'random', 'average']
__all__ = [
'resample',
'to_mono',
'depth_convert',
'normalize',
'save_wav',
'load',
]
NORMALMIZE_TYPES = ['linear', 'gaussian']
MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
EPS = 1e-8
__all__ = ['resample', 'to_mono', 'depth_convert', 'normalize', 'save', 'load']
def resample(y: array,
src_sr: int,
target_sr: int,
mode: str = 'kaiser_fast') -> array:
""" Audio resampling
def resample(y, src_sr, target_sr):
This function is the same as using resampy.resample().
warnings.warn(
f'Using resampy to {src_sr}=>{target_sr}. This function is pretty slow, we recommend to process audio using ffmpeg'
)
assert type(y) == np.ndarray, 'currently only numpy data are supported'
assert type(
src_sr) == int and src_sr > 0 and src_sr <= 48000, 'make sure type(sr) == int and sr > 0 and sr <= 48000,'
assert type(
target_sr
) == int and target_sr > 0 and target_sr <= 48000, 'make sure type(sr) == int and sr > 0 and sr <= 48000,'
Notes:
The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast'
if has_resampy:
return resampy.resample(y, src_sr, target_sr)
"""
if has_librosa:
return librosa.resample(y, src_sr, target_sr)
if mode == 'kaiser_best':
warnings.warn(
f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \
we recommend the mode kaiser_fast in large scale audio trainning')
assert False, 'requires librosa or resampy to do resampling, pip install resampy'
if not isinstance(y, np.ndarray):
raise ParameterError(
'Only support numpy array, but received y in {type(y)}')
if mode not in RESAMPLE_MODES:
raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
def to_mono(y, mono_type='average'):
return resampy.resample(y, src_sr, target_sr, filter=mode)
assert type(y) == np.ndarray, 'currently only numpy data are supported'
if mono_type not in __mono_types__:
assert False, 'Unsupported mono_type {}, available types are {}'.format(mono_type, __mono_types__)
if y.ndim == 1:
return y
def to_mono(y: array, merge_type: str = 'average') -> array:
""" convert sterior audio to mono
"""
if merge_type not in MERGE_TYPES:
raise ParameterError(
f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
)
if y.ndim > 2:
assert False, 'Unsupported audio array, y.ndim > 2, the shape is {}'.format(y.shape)
if mono_type == 'ch0':
raise ParameterError(
f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
if y.ndim == 1: # nothing to merge
return y
if merge_type == 'ch0':
return y[0]
if mono_type == 'ch1':
if merge_type == 'ch1':
return y[1]
if mono_type == 'random':
if merge_type == 'random':
return y[np.random.randint(0, 2)]
# need to do averaging according to dtype
if y.dtype == 'float32':
return (y[0] + y[1]) * 0.5
if y.dtype == 'int16':
y1 = y.astype('int32')
y1 = (y1[0] + y1[1]) // 2
y1 = np.clip(y1, np.iinfo(y.dtype).min, np.iinfo(y.dtype).max).astype(y.dtype)
return y1
if y.dtype == 'int8':
y1 = y.astype('int16')
y1 = (y1[0] + y1[1]) // 2
y1 = np.clip(y1, np.iinfo(y.dtype).min, np.iinfo(y.dtype).max).astype(y.dtype)
return y1
assert False, 'Unsupported audio array type, y.dtype={}'.format(y.dtype)
def __safe_cast__(y, dtype):
y_out = (y[0] + y[1]) * 0.5
elif y.dtype == 'int16':
y_out = y.astype('int32')
y_out = (y_out[0] + y_out[1]) // 2
y_out = np.clip(y_out,
np.iinfo(y.dtype).min,
np.iinfo(y.dtype).max).astype(y.dtype)
elif y.dtype == 'int8':
y_out = y.astype('int16')
y_out = (y_out[0] + y_out[1]) // 2
y_out = np.clip(y_out,
np.iinfo(y.dtype).min,
np.iinfo(y.dtype).max).astype(y.dtype)
else:
raise ParameterError(f'Unsupported dtype: {y.dtype}')
return y_out
def _safe_cast(y: array, dtype: Union[type, str]) -> array:
""" data type casting in a safe way, i.e., prevent overflow or underflow
This function is used internally.
"""
return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
def depth_convert(y, dtype): # convert audio array to target dtype
def depth_convert(y: array,
dtype: Union[type, str],
dithering: bool = True) -> array:
"""Convert audio array to target dtype safely
This function convert audio waveform to a target dtype, with addition steps of
preventing overflow/underflow and preserving audio range.
assert type(y) == np.ndarray, 'currently only numpy data are supported'
"""
if dithering:
warnings.warn('dithering is not implemented')
__eps__ = 1e-5
__supported_dtype__ = ['int16', 'int8', 'float32', 'float64']
if y.dtype not in __supported_dtype__:
assert False, 'Unsupported audio dtype, y.dtype is {}, supported dtypes are {}'.format(
y.dtype, __supported_dtype__)
if dtype not in __supported_dtype__:
assert False, 'Unsupported dtype, target dtype is {}, supported dtypes are {}'.format(
dtype, __supported_dtype__)
SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64']
if y.dtype not in SUPPORT_DTYPE:
raise ParameterError(
f'Unsupported audio dtype, '
'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}')
if dtype not in SUPPORT_DTYPE:
raise ParameterError(
f'Unsupported audio dtype, '
'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}')
if dtype == y.dtype:
return y
if dtype == 'float64' and y.dtype == 'float32':
return __safe_cast__(y, dtype)
return _safe_cast(y, dtype)
if dtype == 'float32' and y.dtype == 'float64':
return __safe_cast__(y, dtype)
return _safe_cast(y, dtype)
if dtype == 'int16' or dtype == 'int8':
if y.dtype in ['float64', 'float32']:
factor = np.iinfo(dtype).max
y = np.clip(y * factor, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
y = np.clip(y * factor,
np.iinfo(dtype).min,
np.iinfo(dtype).max).astype(dtype)
y = y.astype(dtype)
# figure
# plot(y)
# show()
else:
if dtype == 'int16' and y.dtype == 'int8':
factor = np.iinfo('int16').max / np.iinfo('int8').max - __eps__
factor = np.iinfo('int16').max / np.iinfo('int8').max - EPS
y = y.astype('float32') * factor
y = y.astype('int16')
else: #dtype == 'int8' and y.dtype=='int16':
y = y.astype('int32') * np.iinfo('int8').max / np.iinfo('int16').max
else: # dtype == 'int8' and y.dtype=='int16':
y = y.astype('int32') * np.iinfo('int8').max / \
np.iinfo('int16').max
y = y.astype('int8')
if dtype in ['float32', 'float64']:
......@@ -144,7 +169,18 @@ def depth_convert(y, dtype): # convert audio array to target dtype
return y
def sound_file_load(file, offset=None, dtype='int16', duration=None):
def sound_file_load(file: str,
offset: Optional[float] = None,
dtype: str = 'int16',
duration: Optional[int] = None) -> Tuple[array, int]:
"""Load audio using soundfile library
This function load audio file using libsndfile.
Reference:
http://www.mega-nerd.com/libsndfile/#Features
"""
with sf.SoundFile(file) as sf_desc:
sr_native = sf_desc.samplerate
if offset:
......@@ -158,84 +194,117 @@ def sound_file_load(file, offset=None, dtype='int16', duration=None):
return y, sf_desc.samplerate
def normalize(y, norm_type='linear', mul_factor=1.0):
def audio_file_load():
"""Load audio using audiofile library
assert type(y) == np.ndarray, 'currently only numpy data are supported'
This function load audio file using audiofile.
Reference:
https://audiofile.68k.org/
"""
raise NotImplementedError()
def sox_file_load():
"""Load audio using sox library
This function load audio file using sox.
Reference:
http://sox.sourceforge.net/
"""
raise NotImplementedError()
def normalize(y: array,
norm_type: str = 'linear',
mul_factor: float = 1.0) -> array:
""" normalize an input audio with additional multiplier.
"""
__eps__ = 1e-8
#set_trace()
if norm_type == 'linear':
# amin = np.min(y)
amax = np.max(np.abs(y))
factor = 1.0 / (amax + __eps__)
factor = 1.0 / (amax + EPS)
y = y * factor * mul_factor
elif norm_type == 'gaussian':
amean = np.mean(y)
mul_factor = max(0.01, min(mul_factor, 0.2))
astd = np.std(y)
y = mul_factor * (y - amean) / (astd + __eps__)
astd = max(astd, EPS)
y = mul_factor * (y - amean) / astd
else:
assert False, 'not implemented error, norm_type should be in {}'.format(__norm_types__)
raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}')
return y
def save(y, sr, file):
assert type(y) == np.ndarray, 'currently only numpy data are supported'
assert type(sr) == int and sr > 0 and sr <= 48000, 'make sure type(sr) == int and sr > 0 and sr <= 48000,'
def save_wav(y: array, sr: int, file: str) -> None:
"""Save audio file to disk.
This function saves audio to disk using scipy.io.wavfile, with additional step
to convert input waveform to int16 unless it already is int16
Notes:
It only support raw wav format.
"""
if not file.endswith('.wav'):
raise ParameterError(
f'only .wav file supported, but dst file name is: {file}')
if sr <= 0:
raise ParameterError(
f'Sample rate should be larger than 0, recieved sr = {sr}')
if y.dtype not in ['int16', 'int8']:
warnings.warn('input data type is {}, saving data to int16 format'.format(y.dtype))
yout = depth_convert(y, 'int16')
warnings.warn(
f'input data type is {y.dtype}, will convert data to int16 format before saving'
)
y_out = depth_convert(y, 'int16')
else:
yout = y
y_out = y
wavfile.write(file, sr, y)
wavfile.write(file, sr, y_out)
def load(
file,
sr=None,
mono=True,
mono_type='average', # ch0,ch1,random,average
normal=True,
norm_type='linear',
norm_mul_factor=1.0,
offset=0.0,
duration=None,
dtype='float32'):
if has_librosa:
y, r = librosa.load(file, sr=sr, mono=False, offset=offset, duration=duration,
dtype='float32') #alwasy load in float32, then convert to target dtype
elif has_snf:
y, r = sound_file_load(file, offset=offset, dypte=dtype, duration=duration)
file: str,
sr: Optional[int] = None,
mono: bool = True,
merge_type: str = 'average', # ch0,ch1,random,average
normal: bool = True,
norm_type: str = 'linear',
norm_mul_factor: float = 1.0,
offset: float = 0.0,
duration: Optional[int] = None,
dtype: str = 'float32',
resample_mode: str = 'kaiser_fast') -> Tuple[array, int]:
"""Load audio file from disk.
This function loads audio from disk using using audio beackend.
else:
assert False, 'not implemented error'
Parameters:
Notes:
"""
y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration)
##
assert (y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0), 'audio file {} looks empty'.format(file)
if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)):
raise ParameterError(f'audio file {file} looks empty')
if mono:
y = to_mono(y, mono_type)
y = to_mono(y, merge_type)
if sr is not None and sr != r:
y = resample(y, r, sr)
y = resample(y, r, sr, mode=resample_mode)
r = sr
if normal:
# print('before nom',np.max(y))
y = normalize(y, norm_type, norm_mul_factor)
# print('after norm',np.max(y))
#plot(y)
#show()
if dtype in ['int8', 'int16'] and (normalize == False or normalize == True and norm_type == 'guassian'):
y = normalize(y, 'linear', 1.0) # do normalization before converting to target dtype
elif dtype in ['int8', 'int16']:
# still need to do normalization, before depth convertion
y = normalize(y, 'linear', 1.0)
y = depth_convert(y, dtype)
#figure
#plot(y)
#show()
return y, r
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .augmentation import *
from .features import *
from .augment import *
from .core import *
......@@ -12,29 +12,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Optional, Tuple, TypeVar
import numpy as np
import paddle
from numpy import ndarray as array
from paddleaudio.backends import depth_convert
from paddleaudio.utils import ParameterError
__all__ = [
'depth_augment',
'spect_augment',
'random_crop1d',
'random_crop2d',
'adaptive_spect_augment',
]
def randint(high: int) -> int:
"""Generate one random integer in range [0 high)
from ..backends import depth_convert
from .utils import randint, weighted_sampling
This is a helper function for random data augmentaiton
"""
return int(np.random.randint(0, high=high))
__all__ = ['depth_augment', 'spect_augment', 'random_crop1d', 'random_crop2d']
def rand() -> float:
"""Generate one floating-point number in range [0 1)
# example y = depth_augment(y,['int8','int16'],[0.8,0.1])
def depth_augment(y, choices=['int8', 'int16'], probs=[0.5, 0.5]):
assert len(probs) == len(choices), 'number of choices {} must be equal to size of probs {}'.format(
This is a helper function for random data augmentaiton
"""
return float(np.random.rand(1))
def depth_augment(y: array,
choices: List = ['int8', 'int16'],
probs: List[float] = [0.5, 0.5]) -> array:
""" Audio depth augmentation
Do audio depth augmentation to simulate the distortion brought by quantization.
"""
assert len(probs) == len(
choices
), 'number of choices {} must be equal to size of probs {}'.format(
len(choices), len(probs))
k = weighted_sampling(probs)
#k = randint(len(choices))
depth = np.random.choice(choices, p=probs)
src_depth = y.dtype
y1 = depth_convert(y, choices[k])
y1 = depth_convert(y, depth)
y2 = depth_convert(y1, src_depth)
return y2
def adaptive_spect_augment(spect, tempo_axis=0, level=0.1):
def adaptive_spect_augment(spect: array,
tempo_axis: int = 0,
level: float = 0.1) -> array:
"""Do adpative spectrogram augmentation
The level of the augmentation is gowern by the paramter level,
ranging from 0 to 1, with 0 represents no augmentation。
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
nt, nf = spect.shape
......@@ -47,36 +84,35 @@ def adaptive_spect_augment(spect, tempo_axis=0, level=0.1):
num_time_mask = int(10 * level)
num_freq_mask = int(10 * level)
# num_zeros = num_time_mask*time_mask_width*nf + num_freq_mask*freq_mask_width*nt
# factor = (nt*nf)/(nt*nf-num_zeros)
if tempo_axis == 0:
for i in range(num_time_mask):
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for i in range(num_freq_mask):
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
else:
for i in range(num_time_mask):
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for i in range(num_freq_mask):
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def spect_augment(
spect,
tempo_axis=0,
max_time_mask=3,
max_freq_mask=3,
max_time_mask_width=30,
max_freq_mask_width=20,
):
def spect_augment(spect: array,
tempo_axis: int = 0,
max_time_mask: int = 3,
max_freq_mask: int = 3,
max_time_mask_width: int = 30,
max_freq_mask_width: int = 20) -> array:
"""Do spectrogram augmentation in both time and freq axis
Reference:
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
nt, nf = spect.shape
......@@ -89,42 +125,47 @@ def spect_augment(
time_mask_width = randint(max_time_mask_width)
freq_mask_width = randint(max_freq_mask_width)
#print(num_time_mask)
#print(num_freq_mask)
if tempo_axis == 0:
for i in range(num_time_mask):
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for i in range(num_freq_mask):
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
else:
for i in range(num_time_mask):
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for i in range(num_freq_mask):
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def random_crop1d(y, crop_len):
assert y.ndim == 1, 'only accept 1d tensor or numpy array'
def random_crop1d(y: array, crop_len: int) -> array:
""" Do random cropping on 1d input signal
The input is a 1d signal, typically a sound waveform
"""
if y.ndim != 1:
'only accept 1d tensor or numpy array'
n = len(y)
idx = randint(n - crop_len)
return y[idx:idx + crop_len]
def random_crop2d(s, crop_len, tempo_axis=0): # random crop according to temporal direction
assert tempo_axis < s.ndim, 'axis out of range'
def random_crop2d(s: array, crop_len: int, tempo_axis: int = 0) -> array:
""" Do random cropping for 2D array, typically a spectrogram.
The cropping is done in temporal direction on the time-freq input signal.
"""
if tempo_axis >= s.ndim:
raise ParameterError('axis out of range')
n = s.shape[tempo_axis]
idx = randint(high=n - crop_len)
if type(s) == np.ndarray:
sli = [slice(None) for i in range(s.ndim)]
sli[tempo_axis] = slice(idx, idx + crop_len)
out = s[tuple(sli)]
else:
out = paddle.index_select(s, paddle.Tensor(np.array([i for i in range(idx, idx + crop_len)])), axis=tempo_axis)
sli = [slice(None) for i in range(s.ndim)]
sli[tempo_axis] = slice(idx, idx + crop_len)
out = s[tuple(sli)]
return out
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional, Union
import numpy as np
import scipy
from numpy import ndarray as array
from numpy.lib.stride_tricks import as_strided
from paddleaudio.utils import ParameterError
from scipy.signal import get_window
__all__ = [
'stft',
'mfcc',
'hz_to_mel',
'mel_to_hz',
'split_frames',
'mel_frequencies',
'power_to_db',
'compute_fbank_matrix',
'melspectrogram',
'spectrogram',
'mu_encode',
'mu_decode',
]
def pad_center(data: array, size: int, axis: int = -1, **kwargs) -> array:
"""Pad an array to a target length along a target axis.
This differs from `np.pad` by centering the data prior to padding,
analogous to `str.center`
"""
kwargs.setdefault("mode", "constant")
n = data.shape[axis]
lpad = int((size - n) // 2)
lengths = [(0, 0)] * data.ndim
lengths[axis] = (lpad, int(size - n - lpad))
if lpad < 0:
raise ParameterError(("Target size ({size:d}) must be "
"at least input size ({n:d})"))
return np.pad(data, lengths, **kwargs)
def split_frames(x: array,
frame_length: int,
hop_length: int,
axis: int = -1) -> array:
"""Slice a data array into (overlapping) frames.
This function is aligned with librosa.frame
"""
if not isinstance(x, np.ndarray):
raise ParameterError(
f"Input must be of type numpy.ndarray, given type(x)={type(x)}")
if x.shape[axis] < frame_length:
raise ParameterError(f"Input is too short (n={x.shape[axis]:d})"
f" for frame_length={frame_length:d}")
if hop_length < 1:
raise ParameterError(f"Invalid hop_length: {hop_length:d}")
if axis == -1 and not x.flags["F_CONTIGUOUS"]:
warnings.warn(f"librosa.util.frame called with axis={axis} "
"on a non-contiguous input. This will result in a copy.")
x = np.asfortranarray(x)
elif axis == 0 and not x.flags["C_CONTIGUOUS"]:
warnings.warn(f"librosa.util.frame called with axis={axis} "
"on a non-contiguous input. This will result in a copy.")
x = np.ascontiguousarray(x)
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
strides = np.asarray(x.strides)
new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize
if axis == -1:
shape = list(x.shape)[:-1] + [frame_length, n_frames]
strides = list(strides) + [hop_length * new_stride]
elif axis == 0:
shape = [n_frames, frame_length] + list(x.shape)[1:]
strides = [hop_length * new_stride] + list(strides)
else:
raise ParameterError(f"Frame axis={axis} must be either 0 or -1")
return as_strided(x, shape=shape, strides=strides)
def _check_audio(y, mono=True) -> bool:
"""Determine whether a variable contains valid audio data.
The audio y must be a np.ndarray, ether 1-channel or two channel
"""
if not isinstance(y, np.ndarray):
raise ParameterError("Audio data must be of type numpy.ndarray")
if y.ndim > 2:
raise ParameterError(
f"Invalid shape for audio ndim={y.ndim:d}, shape={y.shape}")
if mono and y.ndim == 2:
raise ParameterError(
f"Invalid shape for mono audio ndim={y.ndim:d}, shape={y.shape}")
if (mono and len(y) == 0) or (not mono and y.shape[1] < 0):
raise ParameterError(f"Audio is empty ndim={y.ndim:d}, shape={y.shape}")
if not np.issubdtype(y.dtype, np.floating):
raise ParameterError("Audio data must be floating-point")
if not np.isfinite(y).all():
raise ParameterError("Audio buffer is not finite everywhere")
return True
def hz_to_mel(frequencies: Union[float, List[float], array],
htk: bool = False) -> array:
"""Convert Hz to Mels
This function is aligned with librosa.
"""
freq = np.asanyarray(frequencies)
if htk:
return 2595.0 * np.log10(1.0 + freq / 700.0)
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (freq - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if freq.ndim:
# If we have array data, vectorize
log_t = freq >= min_log_hz
mels[log_t] = min_log_mel + \
np.log(freq[log_t] / min_log_hz) / logstep
elif freq >= min_log_hz:
# If we have scalar data, heck directly
mels = min_log_mel + np.log(freq / min_log_hz) / logstep
return mels
def mel_to_hz(mels: Union[float, List[float], array],
htk: int = False) -> array:
"""Convert mel bin numbers to frequencies.
This function is aligned with librosa.
"""
mel_array = np.asanyarray(mels)
if htk:
return 700.0 * (10.0**(mel_array / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel_array
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if mel_array.ndim:
# If we have vector data, vectorize
log_t = mel_array >= min_log_mel
freqs[log_t] = min_log_hz * \
np.exp(logstep * (mel_array[log_t] - min_log_mel))
elif mel_array >= min_log_mel:
# If we have scalar data, check directly
freqs = min_log_hz * np.exp(logstep * (mel_array - min_log_mel))
return freqs
def mel_frequencies(n_mels: int = 128,
fmin: float = 0.0,
fmax: float = 11025.0,
htk: bool = False) -> array:
"""Compute mel frequencies
This function is aligned with librosa.
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(fmin, htk=htk)
max_mel = hz_to_mel(fmax, htk=htk)
mels = np.linspace(min_mel, max_mel, n_mels)
return mel_to_hz(mels, htk=htk)
def fft_frequencies(sr: int, n_fft: int) -> array:
"""Compute fourier frequencies.
This function is aligned with librosa.
"""
return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True)
def compute_fbank_matrix(sr: int,
n_fft: int,
n_mels: int = 128,
fmin: float = 0.0,
fmax: Optional[float] = None,
htk: bool = False,
norm: str = "slaney",
dtype: type = np.float32):
"""Compute fbank matrix.
This funciton is aligned with librosa.
"""
if norm != "slaney":
raise ParameterError('norm must set to slaney')
if fmax is None:
fmax = float(sr) / 2
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
# 'Center freqs' of mel bands - uniformly spaced between limits
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
if norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
# Only check weights if f_mel[0] is positive
if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
# This means we have an empty channel somewhere
warnings.warn("Empty filters detected in mel frequency basis. "
"Some channels will produce empty responses. "
"Try increasing your sampling rate (and fmax) or "
"reducing n_mels.")
return weights
def stft(x: array,
n_fft: int = 2048,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: str = "hann",
center: bool = True,
dtype: type = np.complex64,
pad_mode: str = "reflect") -> array:
"""Short-time Fourier transform (STFT).
This function is aligned with librosa.
"""
_check_audio(x)
# By default, use the entire frame
if win_length is None:
win_length = n_fft
# Set the default hop, if it's not already specified
if hop_length is None:
hop_length = int(win_length // 4)
fft_window = get_window(window, win_length, fftbins=True)
# Pad the window out to n_fft size
fft_window = pad_center(fft_window, n_fft)
# Reshape so that the window can be broadcast
fft_window = fft_window.reshape((-1, 1))
# Pad the time series so that frames are centered
if center:
if n_fft > x.shape[-1]:
warnings.warn(
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
)
x = np.pad(x, int(n_fft // 2), mode=pad_mode)
elif n_fft > x.shape[-1]:
raise ParameterError(
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
)
# Window the time series.
x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length)
# Pre-allocate the STFT matrix
stft_matrix = np.empty((int(1 + n_fft // 2), x_frames.shape[1]),
dtype=dtype,
order="F")
fft = np.fft # use numpy fft as default
# Constrain STFT block sizes to 256 KB
MAX_MEM_BLOCK = 2**8 * 2**10
# how many columns can we fit within MAX_MEM_BLOCK?
n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize)
n_columns = max(n_columns, 1)
for bl_s in range(0, stft_matrix.shape[1], n_columns):
bl_t = min(bl_s + n_columns, stft_matrix.shape[1])
stft_matrix[:,
bl_s:bl_t] = fft.rfft(fft_window * x_frames[:, bl_s:bl_t],
axis=0)
return stft_matrix
def power_to_db(spect: array,
ref: float = 1.0,
amin: float = 1e-10,
top_db: Optional[float] = 80.0) -> array:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units
This computes the scaling ``10 * log10(spect / ref)`` in a numerically
stable way.
This function is aligned with librosa.
"""
spect = np.asarray(spect)
if amin <= 0:
raise ParameterError("amin must be strictly positive")
if np.issubdtype(spect.dtype, np.complexfloating):
warnings.warn(
"power_to_db was called on complex input so phase "
"information will be discarded. To suppress this warning, "
"call power_to_db(np.abs(D)**2) instead.")
magnitude = np.abs(spect)
else:
magnitude = spect
if callable(ref):
# User supplied a function to calculate reference power
ref_value = ref(magnitude)
else:
ref_value = np.abs(ref)
log_spec = 10.0 * np.log10(np.maximum(amin, magnitude))
log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
if top_db is not None:
if top_db < 0:
raise ParameterError("top_db must be non-negative")
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
return log_spec
def mfcc(x,
sr: int = 16000,
spect: Optional[array] = None,
n_mfcc: int = 20,
dct_type: int = 2,
norm: str = "ortho",
lifter: int = 0,
**kwargs) -> array:
"""Mel-frequency cepstral coefficients (MFCCs)
This function is NOT strictly aligned with librosa. The following example shows how to get the
same result with librosa:
# paddleaudioe mfcc:
kwargs = {
'window_size':512,
'hop_length':320,
'mel_bins':64,
'fmin':50,
'to_db':False}
a = mfcc(x,
spect=None,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0,
**kwargs)
# librosa mfcc:
spect = librosa.feature.melspectrogram(x,sr=16000,n_fft=512,
win_length=512,
hop_length=320,
n_mels=64, fmin=50)
b = librosa.feature.mfcc(x,
sr=16000,
S=spect,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0)
assert np.mean( (a-b)**2) < 1e-8
"""
if spect is None:
spect = melspectrogram(x, sr=sr, **kwargs)
M = scipy.fftpack.dct(spect, axis=0, type=dct_type, norm=norm)[:n_mfcc]
if lifter > 0:
factor = np.sin(np.pi * np.arange(1, 1 + n_mfcc, dtype=M.dtype) /
lifter)
return M * factor[:, np.newaxis]
elif lifter == 0:
return M
else:
raise ParameterError(
f"MFCC lifter={lifter} must be a non-negative number")
def melspectrogram(x: array,
sr: int = 16000,
window_size: int = 512,
hop_length: int = 320,
n_mels: int = 64,
fmin: int = 50,
fmax: Optional[float] = None,
window: str = 'hann',
center: bool = True,
pad_mode: str = 'reflect',
power: float = 2.0,
to_db: bool = True,
ref: float = 1.0,
amin: float = 1e-10,
top_db: Optional[float] = None) -> array:
"""Compute mel-spectrogram.
Parameters:
x: numpy.ndarray
The input wavform is a numpy array [shape=(n,)]
window_size: int, typically 512, 1024, 2048, etc.
The window size for framing, also used as n_fft for stft
Returns:
The mel-spectrogram in power scale or db scale(default)
Notes:
1. sr is default to 16000, which is commonly used in speech/speaker processing.
2. when fmax is None, it is set to sr//2.
3. this function will convert mel spectgrum to db scale by default. This is different
that of librosa.
"""
_check_audio(x, mono=True)
if len(x) <= 0:
raise ParameterError('The input waveform is empty')
if fmax is None:
fmax = sr // 2
if fmin < 0 or fmin >= fmax:
raise ParameterError('fmin and fmax must statisfy 0<fmin<fmax')
s = stft(x,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
spect_power = np.abs(s)**power
fb_matrix = compute_fbank_matrix(sr=sr,
n_fft=window_size,
n_mels=n_mels,
fmin=fmin,
fmax=fmax)
mel_spect = np.matmul(fb_matrix, spect_power)
if to_db:
return power_to_db(mel_spect, ref=ref, amin=amin, top_db=top_db)
else:
return mel_spect
def spectrogram(x: array,
sr: int = 16000,
window_size: int = 512,
hop_length: int = 320,
window: str = 'hann',
center: bool = True,
pad_mode: str = 'reflect',
power: float = 2.0) -> array:
"""Compute spectrogram from an input waveform.
This function is a wrapper for librosa.feature.stft, with addition step to
compute the magnitude of the complex spectrogram.
"""
s = stft(x,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
return np.abs(s)**power
def mu_encode(x: array, mu: int = 255, quantized: bool = True) -> array:
"""Mu-law encoding.
Compute the mu-law decoding given an input code.
When quantized is True, the result will be converted to
integer in range [0,mu-1]. Otherwise, the resulting signal
is in range [-1,1]
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
"""
mu = 255
y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
if quantized:
y = np.floor((y + 1) / 2 * mu + 0.5) # convert to [0 , mu-1]
return y
def mu_decode(y: array, mu: int = 255, quantized: bool = True) -> array:
"""Mu-law decoding.
Compute the mu-law decoding given an input code.
it assumes that the input y is in
range [0,mu-1] when quantize is True and [-1,1] otherwise
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
"""
if mu < 1:
raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...')
mu = mu - 1
if quantized: # undo the quantization
y = y * 2 / mu - 1
x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1)
return x
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa
import numpy as np
import paddle
__all__ = ['mel_spect', 'linear_spect', 'log_spect']
#mel
def mel_spect(y,
sample_rate=16000,
window_size=512,
hop_length=320,
mel_bins=64,
fmin=50,
fmax=14000,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None):
""" compute mel-spectrogram from input waveform y.
Create a Mel filter-bank.
This produces a linear transformation matrix to project
FFT bins onto Mel-frequency bins.
"""
s = librosa.stft(y,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
power = np.abs(s)**2
melW = librosa.filters.mel(sr=sample_rate, n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax)
mel = np.matmul(melW, power)
db = librosa.power_to_db(mel, ref=ref, amin=amin, top_db=None)
return db
def linear_spect(y,
sample_rate=16000,
window_size=512,
hop_length=320,
window='hann',
center=True,
pad_mode='reflect',
power=2):
s = librosa.stft(y,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
return np.abs(s)**power
def log_spect(y,
sample_rate=16000,
window_size=512,
hop_length=320,
window='hann',
center=True,
pad_mode='reflect',
power=2.0,
offset=1.0):
s = librosa.stft(
y,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode,
)
s = np.abs(s)**power
return np.log(offset + s) # remove
# ESC: Environmental Sound Classification
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cnn6 import CNN6
from .cnn10 import CNN10
from .cnn14 import CNN14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...utils.log import logger
from .conv import ConvBlock
class CNN10(nn.Layer):
"""
The CNN10(14-layer CNNs) mainly consist of 4 convolutional blocks while each convolutional
block consists of 2 convolutional layers with a kernel size of 3 × 3.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 512
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN10, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
print(f'Loaded CNN10 pretrained parameters from: {checkpoint}')
else:
print('No valid checkpoints for CNN10. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...utils.log import logger
from .conv import ConvBlock
class CNN14(nn.Layer):
"""
The CNN14(14-layer CNNs) mainly consist of 6 convolutional blocks while each convolutional
block consists of 2 convolutional layers with a kernel size of 3 × 3.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 2048
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN14, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
self.fc1 = nn.Linear(2048, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
logger.info(f'Loaded CNN14 pretrained parameters from: {checkpoint}')
else:
logger.error('No valid checkpoints for CNN14. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...utils.log import logger
from .conv import ConvBlock5x5
class CNN6(nn.Layer):
"""
The CNN14(14-layer CNNs) mainly consist of 4 convolutional blocks while each convolutional
block consists of 1 convolutional layers with a kernel size of 5 × 5.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 512
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN6, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
print(f'Loaded CNN6 pretrained parameters from: {checkpoint}')
else:
print('No valid checkpoints for CNN6. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class ConvBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.conv2 = nn.Conv2D(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
self.bn2 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
class ConvBlock5x5(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock5x5, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
import argparse
import glob
import os
import h5py
import numpy as np
import tqdm
import paddleaudio as pa
#from pylab import *
parser = argparse.ArgumentParser(description='wave2mel')
parser.add_argument('--wav_file', type=str, required=False, default='')
parser.add_argument('--wav_list', type=str, required=False, default='')
parser.add_argument('--wav_h5_file', type=str, required=False, default='')
parser.add_argument('--wav_h5_list', type=str, required=False, default='')
parser.add_argument('--output_folder', type=str, required=False, default='./')
parser.add_argument('--output_h5', type=bool, required=False, default=True)
parser.add_argument('--sample_rate', type=int, required=False, default=32000)
parser.add_argument('--window_size', type=int, required=False, default=1024)
parser.add_argument('--mel_bins', type=int, required=False, default=128)
parser.add_argument('--hop_length', type=int, required=False, default=640) #20ms
parser.add_argument('--fmin', type=int, required=False, default=50) #25ms
parser.add_argument('--fmax', type=int, required=False, default=16000) #25ms
parser.add_argument('--skip_existed', type=int, required=False, default=1) #25ms
args = parser.parse_args()
#args.wav_h5_file = '/ssd2/laiyongquan/audioset/h5/audioset_unblance_group28.h5'
assert not (args.wav_h5_file == '' and args.wav_h5_list == ''\
and args.wav_list == '' and args.wav_file == ''), 'one of wav_file,wav_list,\
wav_h5_file,wav_h5_list needs to specify'
if args.wav_h5_file != '':
h5_files = [args.wav_h5_file]
if args.wav_h5_list != '':
h5_files = open(args.wav_h5_list).read().split('\n')
h5_files = [h for h in h5_files if len(h.strip()) != 0]
dst_folder = args.output_folder
print(f'{len(h5_files)} h5 files listed')
for f in h5_files:
print(f'processing {f}')
dst_file = os.path.join(dst_folder, f.split('/')[-1])
print(f'target file {dst_file}')
if args.skip_existed != 0 and os.path.exists(dst_file):
print(f'skipped file {f}')
continue
assert not os.path.exists(dst_file), f'target file {dst_file} existed'
src_h5 = h5py.File(f)
dst_h5 = h5py.File(dst_file, "w")
for key in tqdm.tqdm(src_h5.keys()):
s = src_h5[key][:]
s = pa.depth_convert(s, 'float32')
# s = pa.resample(s,32000,args.sample_rate)
x = pa.features.mel_spect(s,
sample_rate=args.sample_rate,
window_size=args.window_size,
hop_length=args.hop_length,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
window='hann',
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
# figure(figsize=(8,8))
# imshow(x)
# show()
# print(x.shape)
dst_h5.create_dataset(key, data=x)
src_h5.close()
dst_h5.close()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -14,5 +14,6 @@
from .download import *
from .env import *
from .error import *
from .log import *
from .time import *
......@@ -13,26 +13,10 @@
# limitations under the License.
import numpy as np
import paddle
__all__ = ['randint', 'rand', 'weighted_sampling']
__all__ = ['ParameterError']
def randint(high, use_paddle=True):
if use_paddle:
return int(paddle.randint(0, high=high))
return int(np.random.randint(0, high=high))
def rand(use_paddle=True):
if use_paddle:
return float(paddle.rand((1, )))
return float(np.random.rand(1))
def weighted_sampling(weights):
n = len(weights)
w = np.cumsum(weights)
w = w / w[-1]
flag = rand() < w
return np.argwhere(flag)[0][0]
class ParameterError(Exception):
"""Exception class for Parameter checking"""
pass
colorama
colorlog
scipy
librosa
tqdm
numpy >= 1.15.0
scipy >= 1.0.0
resampy >= 0.2.2
soundfile >= 0.9.0
\ No newline at end of file
import setuptools
# set the version here
version = '0.1.0a'
with open("README.md", "r") as fh:
long_description = fh.read()
setuptools.setup(
name="PaddleAudio",
version="0.0.0",
name="PaddleAudio",
version=version,
author="",
author_email="",
description="PaddleAudio, in development",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/ranchlai/PaddleAudio",
packages=setuptools.find_packages(),
url="",
packages=setuptools.find_packages(exclude=["build*", "test*", "examples*"]),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
install_requires=[
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
'soundfile >= 0.9.0'
],
extras_require={'dev': ['pytest>=3.7', 'librosa>=0.7.2']
} # for dev only, install: pip install -e .[dev]
)
# PaddleAudio Testing Guide
# Testing
First clone a version of the project by
```
git clone https://github.com/PaddlePaddle/models.git
```
Then install the project in your virtual environment.
```
cd models/PaddleAudio
python setup.py bdist_wheel
pip install -e .[dev]
```
The requirements for testing will be installed along with PaddleAudio.
Now run
```
pytest test
```
If it goes well, you will see outputs like these:
```
platform linux -- Python 3.7.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1
rootdir: ./models/PaddleAudio
plugins: hydra-core-1.0.6
collected 16 items
test/unit_test/test_backend.py ........... [ 68%]
test/unit_test/test_features.py ..... [100%]
==================================================== warnings summary ====================================================
.
.
.
-- Docs: https://docs.pytest.org/en/stable/warnings.html
============================================ 16 passed, 11 warnings in 6.76s =============================================
```
import librosa
import numpy as np
import paddleaudio
import pytest
import scipy
TEST_FILE = './test/data/test_audio.wav'
def relative_err(a, b, real=True):
"""compute relative error of two matrices or vectors"""
if real:
return np.sum((a - b)**2) / (EPS + np.sum(a**2) + np.sum(b**2))
else:
err = np.sum((a.real-b.real)**2) / \
(EPS+np.sum(a.real**2)+np.sum(b.real**2))
err += np.sum((a.imag-b.imag)**2) / \
(EPS+np.sum(a.imag**2)+np.sum(b.imag**2))
return err
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def load_audio():
x, r = librosa.load(TEST_FILE, sr=16000)
print(f'librosa: mean: {np.mean(x)}, std:{np.std(x)}')
return x, r
# start testing
x, r = load_audio()
EPS = 1e-8
def test_load():
s, r = paddleaudio.load(TEST_FILE, sr=16000)
assert r == 16000
assert s.dtype == 'float32'
s, r = paddleaudio.load(TEST_FILE,
sr=16000,
offset=1,
duration=2,
dtype='int16')
assert len(s) / r == 2.0
assert r == 16000
assert s.dtype == 'int16'
def test_depth_convert():
y = paddleaudio.depth_convert(x, 'int16')
assert len(y) == len(x)
assert y.dtype == 'int16'
assert np.max(y) <= 32767
assert np.min(y) >= -32768
assert np.std(y) > EPS
y = paddleaudio.depth_convert(x, 'int8')
assert len(y) == len(x)
assert y.dtype == 'int8'
assert np.max(y) <= 127
assert np.min(y) >= -128
assert np.std(y) > EPS
# test case for resample
rs_test_data = [
(32000, 'kaiser_fast'),
(16000, 'kaiser_fast'),
(8000, 'kaiser_fast'),
(32000, 'kaiser_best'),
(16000, 'kaiser_best'),
(8000, 'kaiser_best'),
(22050, 'kaiser_best'),
(44100, 'kaiser_best'),
]
@pytest.mark.parametrize('sr,mode', rs_test_data)
def test_resample(sr, mode):
y = paddleaudio.resample(x, 16000, sr, mode=mode)
factor = sr / 16000
err = relative_err(len(y), len(x) * factor)
print('err:', err)
assert err < EPS
def test_normalize():
y = paddleaudio.normalize(x, norm_type='linear', mul_factor=0.5)
assert np.max(y) < 0.5 + EPS
y = paddleaudio.normalize(x, norm_type='linear', mul_factor=2.0)
assert np.max(y) <= 2.0 + EPS
y = paddleaudio.normalize(x, norm_type='gaussian', mul_factor=1.0)
print('np.std(y):', np.std(y))
assert np.abs(np.std(y) - 1.0) < EPS
if __name__ == '__main__':
test_load()
test_depth_convert()
test_resample(22050, 'kaiser_fast')
test_normalize()
import librosa
import numpy as np
import paddleaudio as pa
import pytest
import scipy
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def load_audio():
x, r = librosa.load('./test/data/test_audio.wav')
#x,r = librosa.load('../data/test_audio.wav',sr=16000)
return x, r
## start testing
x, r = load_audio()
EPS = 1e-8
def relative_err(a, b, real=True):
"""compute relative error of two matrices or vectors"""
if real:
return np.sum((a - b)**2) / (EPS + np.sum(a**2) + np.sum(b**2))
else:
err = np.sum((a.real - b.real)**
2) / (EPS + np.sum(a.real**2) + np.sum(b.real**2))
err += np.sum((a.imag - b.imag)**
2) / (EPS + np.sum(a.imag**2) + np.sum(b.imag**2))
return err
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_melspectrogram():
a = pa.melspectrogram(
x,
window_size=512,
sr=16000,
hop_length=320,
n_mels=64,
fmin=50,
to_db=False,
)
b = librosa.feature.melspectrogram(x,
sr=16000,
n_fft=512,
win_length=512,
hop_length=320,
n_mels=64,
fmin=50)
assert relative_err(a, b) < EPS
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_melspectrogram_db():
a = pa.melspectrogram(x,
window_size=512,
sr=16000,
hop_length=320,
n_mels=64,
fmin=50,
to_db=True,
ref=1.0,
amin=1e-10,
top_db=None)
b = librosa.feature.melspectrogram(x,
sr=16000,
n_fft=512,
win_length=512,
hop_length=320,
n_mels=64,
fmin=50)
b = pa.power_to_db(b, ref=1.0, amin=1e-10, top_db=None)
assert relative_err(a, b) < EPS
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_stft():
a = pa.stft(x, n_fft=1024, hop_length=320, win_length=512)
b = librosa.stft(x, n_fft=1024, hop_length=320, win_length=512)
assert a.shape == b.shape
assert relative_err(a, b, real=False) < EPS
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_split_frames():
a = librosa.util.frame(x, frame_length=512, hop_length=320)
b = pa.split_frames(x, frame_length=512, hop_length=320)
assert relative_err(a, b) < EPS
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_mfcc():
kwargs = {
'window_size': 512,
'hop_length': 320,
'n_mels': 64,
'fmin': 50,
'to_db': False
}
a = pa.mfcc(
x,
#sample_rate=16000,
spect=None,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0,
**kwargs)
S = librosa.feature.melspectrogram(x,
sr=16000,
n_fft=512,
win_length=512,
hop_length=320,
n_mels=64,
fmin=50)
b = librosa.feature.mfcc(x,
sr=16000,
S=S,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0)
assert relative_err(a, b) < EPS
if __name__ == '__main__':
test_melspectrogram()
test_melspectrogram_db()
test_stft()
test_split_frames()
test_mfcc()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册