Unverified commit fa34cdf1 authored by Hui Zhang, committed by GitHub

Merge pull request #754 from PaddlePaddle/develop

release 2.1.1
unset GREP_OPTIONS
# https://zhuanlan.zhihu.com/p/33050965
alias nvs='nvidia-smi'
alias his='history'
alias jobs='jobs -l'
alias ports='netstat -tulanp'
alias wget='wget -c'
## Colorize the grep command output for ease of use (good for log files)##
alias grep='grep --color=auto'
alias egrep='egrep --color=auto'
alias fgrep='fgrep --color=auto'
......@@ -10,8 +10,13 @@
.ipynb_checkpoints
*.npz
*.done
*.whl
tools/venv
tools/kenlm
tools/sox-14.4.2
tools/soxbindings
tools/montreal-forced-aligner/
tools/Montreal-Forced-Aligner/
*output/
......@@ -87,3 +87,9 @@ pull_request_rules:
actions:
label:
add: ["Docker"]
- name: "auto add label=Deployment"
conditions:
- files~=^speechnn/
actions:
label:
add: ["Deployment"]
{
"cells": [
{
"cell_type": "code",
"execution_count": 94,
"id": "matched-camera",
"metadata": {},
"outputs": [],
"source": [
"from nnAudio import Spectrogram\n",
"from scipy.io import wavfile\n",
"import torch\n",
"import soundfile as sf\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "quarterly-solution",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n"
]
}
],
"source": [
"import scipy.io.wavfile as wav\n",
"\n",
"rate,sig = wav.read('./BAC009S0764W0124.wav')\n",
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sig)\n",
"print(song)\n",
"print(sample)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "middle-salem",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2733 seconds\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n"
]
}
],
"source": [
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"print(sr)\n",
"print(song)\n",
"print(song.shape)\n",
"print(song.dtype)\n",
"x = song\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(spec)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "finished-sterling",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"True\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2001 seconds\n",
"torch.Size([1, 1025, 164, 2])\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n",
"True\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav)\n",
"print(wav.shape)\n",
"print(wav.dtype)\n",
"print(np.allclose(wav, song))\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(wav_spec.shape)\n",
"print(wav_spec)\n",
"print(np.allclose(wav_spec, spec))"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "running-technology",
"metadata": {},
"outputs": [],
"source": [
"import decimal\n",
"\n",
"import numpy\n",
"import math\n",
"import logging\n",
"def round_half_up(number):\n",
" return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n",
"\n",
"\n",
"def rolling_window(a, window, step=1):\n",
" # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n",
" shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n",
" strides = a.strides + (a.strides[-1],)\n",
" return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n",
"\n",
"\n",
"def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n",
" \"\"\"Frame a signal into overlapping frames.\n",
"\n",
" :param sig: the audio signal to frame.\n",
" :param frame_len: length of each frame measured in samples.\n",
" :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n",
" :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n",
" :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n",
" :returns: an array of frames. Size is NUMFRAMES by frame_len.\n",
" \"\"\"\n",
" slen = len(sig)\n",
" frame_len = int(round_half_up(frame_len))\n",
" frame_step = int(round_half_up(frame_step))\n",
" if slen <= frame_len:\n",
" numframes = 1\n",
" else:\n",
" numframes = 1 + (( slen - frame_len) // frame_step)\n",
"\n",
" # check kaldi/src/feat/feature-window.h\n",
" padsignal = sig[:(numframes-1)*frame_step+frame_len]\n",
" if wintype is 'povey':\n",
" win = numpy.empty(frame_len)\n",
" for i in range(frame_len):\n",
" win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n",
" else: # the hamming window\n",
" win = numpy.hamming(frame_len)\n",
" \n",
" if stride_trick:\n",
" frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n",
" else:\n",
" indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n",
" numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n",
" indices = numpy.array(indices, dtype=numpy.int32)\n",
" frames = padsignal[indices]\n",
" win = numpy.tile(win, (numframes, 1))\n",
" \n",
" frames = frames.astype(numpy.float32)\n",
" raw_frames = numpy.zeros(frames.shape)\n",
" for frm in range(frames.shape[0]):\n",
" raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n",
" frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n",
" # raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n",
"\n",
" return frames * win, raw_frames\n",
"\n",
"\n",
"def magspec(frames, NFFT):\n",
" \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n",
" \"\"\"\n",
" if numpy.shape(frames)[1] > NFFT:\n",
" logging.warn(\n",
" 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n",
" numpy.shape(frames)[1], NFFT)\n",
" complex_spec = numpy.fft.rfft(frames, NFFT)\n",
" return numpy.absolute(complex_spec)\n",
"\n",
"\n",
"def powspec(frames, NFFT):\n",
" \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n",
" \"\"\"\n",
" return numpy.square(magspec(frames, NFFT))\n",
"\n",
"\n",
"def do_dither(signal, dither_value=1.0):\n",
" signal += numpy.random.normal(size=signal.shape) * dither_value\n",
" return signal\n",
" \n",
"def do_remove_dc_offset(signal):\n",
" signal -= numpy.mean(signal)\n",
" return signal\n",
"\n",
"def do_preemphasis(signal, coeff=0.97):\n",
" \"\"\"perform preemphasis on the input signal.\n",
"\n",
" :param signal: The signal to filter.\n",
" :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n",
" :returns: the filtered signal.\n",
" \"\"\"\n",
" return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "ignored-retreat",
"metadata": {},
"outputs": [],
"source": [
"def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" spec = magspec(frames, nfft) # nearly the same until this part\n",
" rspec = magspec(raw_frames, nfft)\n",
" return spec, rspec\n",
"\n",
"\n",
"\n",
"def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" return raw_frames"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "federal-teacher",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"import scipy # used only in CFP\n",
"\n",
"import numpy as np\n",
"from time import time\n",
"\n",
"def pad_center(data, size, axis=-1, **kwargs):\n",
"\n",
" kwargs.setdefault('mode', 'constant')\n",
"\n",
" n = data.shape[axis]\n",
"\n",
" lpad = int((size - n) // 2)\n",
"\n",
" lengths = [(0, 0)] * data.ndim\n",
" lengths[axis] = (lpad, int(size - n - lpad))\n",
"\n",
" if lpad < 0:\n",
" raise ParameterError(('Target size ({:d}) must be '\n",
" 'at least input size ({:d})').format(size, n))\n",
"\n",
" return np.pad(data, lengths, **kwargs)\n",
"\n",
"\n",
"\n",
"sz_float = 4 # size of a float\n",
"epsilon = 10e-8 # fudge factor for normalization\n",
"\n",
"def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n",
" freq_scale='linear', window='hann', verbose=True):\n",
"\n",
" if freq_bins==None: freq_bins = n_fft//2+1\n",
" if win_length==None: win_length = n_fft\n",
"\n",
" s = np.arange(0, n_fft, 1.)\n",
" wsin = np.empty((freq_bins,1,n_fft))\n",
" wcos = np.empty((freq_bins,1,n_fft))\n",
" start_freq = fmin\n",
" end_freq = fmax\n",
" bins2freq = []\n",
" binslist = []\n",
"\n",
" # num_cycles = start_freq*d/44000.\n",
" # scaling_ind = np.log(end_freq/start_freq)/k\n",
"\n",
" # Choosing window shape\n",
"\n",
" #window_mask = get_window(window, int(win_length), fftbins=True)\n",
" window_mask = np.hamming(int(win_length))\n",
" window_mask = pad_center(window_mask, n_fft)\n",
"\n",
" if freq_scale == 'linear':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" \n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n",
" bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n",
" binslist.append((k*scaling_ind+start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'log':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = np.log(end_freq/start_freq)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n",
" bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n",
" binslist.append((np.exp(k*scaling_ind)*start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'no':\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" bins2freq.append(k*sr/n_fft)\n",
" binslist.append(k)\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
" else:\n",
" print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n",
" return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n",
"\n",
"\n",
"\n",
"def broadcast_dim(x):\n",
" \"\"\"\n",
" Auto broadcast input so that it can fits into a Conv1d\n",
" \"\"\"\n",
"\n",
" if x.dim() == 2:\n",
" x = x[:, None, :]\n",
" elif x.dim() == 1:\n",
" # If nn.DataParallel is used, this broadcast doesn't work\n",
" x = x[None, None, :]\n",
" elif x.dim() == 3:\n",
" pass\n",
" else:\n",
" raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n",
" return x\n",
"\n",
"\n",
"\n",
"### --------------------------- Spectrogram Classes ---------------------------###\n",
"class STFT(torch.nn.Module):\n",
"\n",
" def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n",
" freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n",
" fmin=50, fmax=6000, sr=22050, trainable=False,\n",
" output_format=\"Complex\", verbose=True):\n",
"\n",
" super().__init__()\n",
"\n",
" # Trying to make the default setting same as librosa\n",
" if win_length==None: win_length = n_fft\n",
" if hop_length==None: hop_length = int(win_length // 4)\n",
"\n",
" self.output_format = output_format\n",
" self.trainable = trainable\n",
" self.stride = hop_length\n",
" self.center = center\n",
" self.pad_mode = pad_mode\n",
" self.n_fft = n_fft\n",
" self.freq_bins = freq_bins\n",
" self.trainable = trainable\n",
" self.pad_amount = self.n_fft // 2\n",
" self.window = window\n",
" self.win_length = win_length\n",
" self.iSTFT = iSTFT\n",
" self.trainable = trainable\n",
" start = time()\n",
"\n",
"\n",
"\n",
" # Create filter windows for stft\n",
" kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n",
" win_length=win_length,\n",
" freq_bins=freq_bins,\n",
" window=window,\n",
" freq_scale=freq_scale,\n",
" fmin=fmin,\n",
" fmax=fmax,\n",
" sr=sr,\n",
" verbose=verbose)\n",
"\n",
"\n",
" kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n",
" kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n",
" \n",
" # In this way, the inverse kernel and the forward kernel do not share the same memory...\n",
" kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n",
" kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n",
" \n",
" if iSTFT:\n",
" self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n",
" self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n",
"\n",
" # Applying window functions to the Fourier kernels\n",
" if window:\n",
" window_mask = torch.tensor(window_mask)\n",
" wsin = kernel_sin * window_mask\n",
" wcos = kernel_cos * window_mask\n",
" else:\n",
" wsin = kernel_sin\n",
" wcos = kernel_cos\n",
" \n",
" if self.trainable==False:\n",
" self.register_buffer('wsin', wsin)\n",
" self.register_buffer('wcos', wcos) \n",
" \n",
" if self.trainable==True:\n",
" wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n",
" wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n",
" self.register_parameter('wsin', wsin)\n",
" self.register_parameter('wcos', wcos) \n",
" \n",
" # Prepare the shape of window mask so that it can be used later in inverse\n",
" # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n",
" \n",
" if verbose==True:\n",
" print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n",
" else:\n",
" pass\n",
"\n",
" def forward(self, x, output_format=None):\n",
" \"\"\"\n",
" Convert a batch of waveforms to spectrograms.\n",
" \n",
" Parameters\n",
" ----------\n",
" x : torch tensor\n",
" Input signal should be in either of the following shapes.\\n\n",
" 1. ``(len_audio)``\\n\n",
" 2. ``(num_audio, len_audio)``\\n\n",
" 3. ``(num_audio, 1, len_audio)``\n",
" It will be automatically broadcast to the right shape\n",
" \n",
" output_format : str\n",
" Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n",
" Default value is ``Complex``. \n",
" \n",
" \"\"\"\n",
" output_format = output_format or self.output_format\n",
" self.num_samples = x.shape[-1]\n",
" \n",
" x = broadcast_dim(x)\n",
" if self.center:\n",
" if self.pad_mode == 'constant':\n",
" padding = nn.ConstantPad1d(self.pad_amount, 0)\n",
"\n",
" elif self.pad_mode == 'reflect':\n",
" if self.num_samples < self.pad_amount:\n",
" raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n",
" padding = nn.ReflectionPad1d(self.pad_amount)\n",
"\n",
" x = padding(x)\n",
" spec_imag = conv1d(x, self.wsin, stride=self.stride)\n",
" spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n",
"\n",
" # remove redundant parts\n",
" spec_real = spec_real[:, :self.freq_bins, :]\n",
" spec_imag = spec_imag[:, :self.freq_bins, :]\n",
"\n",
" if output_format=='Magnitude':\n",
" spec = spec_real.pow(2) + spec_imag.pow(2)\n",
" if self.trainable==True:\n",
" return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n",
" else:\n",
" return torch.sqrt(spec)\n",
"\n",
" elif output_format=='Complex':\n",
" return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n",
"\n",
" elif output_format=='Phase':\n",
" return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n",
"\n",
" def inverse(self, X, onesided=True, length=None, refresh_win=True):\n",
" \"\"\"\n",
" This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n",
" which is to convert spectrograms back to waveforms. \n",
" It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n",
" please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n",
" \n",
" Parameters\n",
" ----------\n",
" onesided : bool\n",
" If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n",
" else use ``onesided=False``\n",
"\n",
" length : int\n",
" To make sure the inverse STFT has the same output length of the original waveform, please\n",
" set `length` as your intended waveform length. By default, ``length=None``,\n",
" which will remove ``n_fft//2`` samples from the start and the end of the output.\n",
" \n",
" refresh_win : bool\n",
" Recalculating the window sum square. If you have an input with fixed number of timesteps,\n",
" you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``\n",
" \n",
" \n",
" \"\"\"\n",
" if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n",
" raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n",
" \n",
" assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n",
" \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n",
" \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n",
" if onesided:\n",
" X = extend_fbins(X) # extend freq\n",
"\n",
" \n",
" X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n",
"\n",
" # broadcast dimensions to support 2D convolution\n",
" X_real_bc = X_real.unsqueeze(1)\n",
" X_imag_bc = X_imag.unsqueeze(1)\n",
" a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n",
" b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n",
" \n",
" # compute real and imag part. signal lies in the real part\n",
" real = a1 - b2\n",
" real = real.squeeze(-2)*self.window_mask\n",
"\n",
" # Normalize the amplitude with n_fft\n",
" real /= (self.n_fft)\n",
"\n",
" # Overlap and Add algorithm to connect all the frames\n",
" real = overlap_add(real, self.stride)\n",
" \n",
" # Prepare the window sumsqure for division\n",
" # Only need to create this window once to save time\n",
" # Unless the input spectrograms have different time steps\n",
" if hasattr(self, 'w_sum')==False or refresh_win==True:\n",
" self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n",
" self.nonzero_indices = (self.w_sum>1e-10) \n",
" else:\n",
" pass\n",
" real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n",
" # Remove padding\n",
" if length is None: \n",
" if self.center:\n",
" real = real[:, self.pad_amount:-self.pad_amount]\n",
"\n",
" else:\n",
" if self.center:\n",
" real = real[:, self.pad_amount:self.pad_amount + length] \n",
" else:\n",
" real = real[:, :length] \n",
" \n",
" return real\n",
" \n",
" def extra_repr(self) -> str:\n",
" return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n",
" self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n",
" ) "
]
},
{
"cell_type": "code",
"execution_count": 128,
"id": "unusual-baker",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"(83792,)\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.0153 seconds\n",
"torch.Size([521, 257])\n",
"(522, 257)\n",
"[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n",
" 2.32206573e+01 1.90274487e+01]\n",
" [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n",
" 1.80534729e+02 1.84179596e+02]\n",
" [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n",
" 1.21321175e+02 1.08345497e+02]\n",
" ...\n",
" [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n",
" 1.25670158e+02 1.20691467e+02]\n",
" [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n",
" 9.01831894e+01 9.84055099e+01]\n",
" [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n",
" 8.94473553e+00 9.61348629e+00]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav.shape)\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n",
" window='', freq_scale='linear', center=False, pad_mode='constant',\n",
" fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"wav_spec = wav_spec[0].T\n",
"print(wav_spec.shape)\n",
"\n",
"\n",
"spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(spec.shape)\n",
"\n",
"print(wav_spec.numpy())\n",
"print(rspec)\n",
"# print(spec)\n",
"\n",
"# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n",
"# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
"# dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
"# wintype='hamming')\n",
"# print(rspec)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "white-istanbul",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 129,
"id": "modern-rescue",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n",
" 0.75 0.41317591 0.11697778 0. ]\n"
]
},
{
"data": {
"text/plain": [
"array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n",
" 0.9045085, 0.6545085, 0.3454915, 0.0954915])"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(np.hanning(10))\n",
"from scipy.signal import get_window\n",
"get_window('hann', 10, fftbins=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "professional-journalism",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 153,
"id": "involved-motion",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(522, 400)\n",
"[[ 43. 75. 69. ... 46. 46. 45.]\n",
" [ 210. 215. 216. ... -86. -89. -91.]\n",
" [ 128. 128. 128. ... -154. -151. -151.]\n",
" ...\n",
" [ -60. -61. -61. ... 112. 109. 110.]\n",
" [ 20. 22. 24. ... 91. 87. 87.]\n",
" [ 111. 107. 108. ... -6. -4. -8.]]\n",
"torch.Size([1, 1, 83792])\n",
"torch.Size([400, 1, 512])\n",
"torch.Size([1, 400, 521])\n",
"conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n",
" [ 210., 215., 216., ..., -86., -89., -91.],\n",
" [ 128., 128., 128., ..., -154., -151., -151.],\n",
" ...,\n",
" [-143., -141., -142., ..., 96., 101., 101.],\n",
" [ -60., -61., -61., ..., 112., 109., 110.],\n",
" [ 20., 22., 24., ..., 91., 87., 87.]])\n",
"xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"torch.Size([521, 257])\n",
"yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"yy (522, 257)\n",
"[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(f.shape)\n",
"print(f)\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,n_fft))\n",
"wcos = np.empty((freq_bins,1,n_fft))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
"\n",
"\n",
"wsin = np.empty((n_fft,1,n_fft))\n",
"wcos = np.empty((n_fft,1,n_fft))\n",
"for k in range(n_fft): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" \n",
" \n",
"wsin = np.empty((400,1,n_fft))\n",
"wcos = np.empty((400,1,n_fft))\n",
"for k in range(400): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(400, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(400, n_fft)[k]\n",
" \n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :]\n",
"print(x.size())\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160)\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"print(spec_imag.size())\n",
"print(\"conv frame\", spec_imag[0].T)\n",
"# print(spec_imag[0].T[:, :400])\n",
"\n",
"# remove redundant parts\n",
"# spec_real = spec_real[:, :freq_bins, :]\n",
"# spec_imag = spec_imag[:, :freq_bins, :]\n",
"# spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"# spec = torch.sqrt(spec)\n",
"# print(spec)\n",
"\n",
"\n",
"\n",
"s = np.arange(0, 512, 1.)\n",
"# s = s[::-1]\n",
"wsin = np.empty((freq_bins, 400))\n",
"wcos = np.empty((freq_bins, 400))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
"spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n",
"spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n",
"\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"\n",
"print('xx', spec.numpy())\n",
"print(spec.size())\n",
"print('yy', rspec[:521, :])\n",
"print('yy', rspec.shape)\n",
"\n",
"\n",
"x = spec.numpy()\n",
"y = rspec[:-1, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
{
"cell_type": "code",
"execution_count": 160,
"id": "mathematical-traffic",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([257, 1, 400])\n",
"tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n",
" 3.8693e+04, 3.1020e+04],\n",
" [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n",
" 1.3598e+04, 1.5920e+04],\n",
" [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n",
" 6.7707e+03, 4.3020e+03],\n",
" ...,\n",
" [2.0282e+01, 6.6927e+01, 2.8501e+01, ..., 2.6012e+01,\n",
" 6.1071e+01, 5.3685e+01],\n",
" [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n",
" 5.1310e+01, 6.3620e+01],\n",
" [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n",
" 3.5000e+01, 4.4000e+01]]])\n",
"[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n",
" 2.4064537e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n",
" 9.5781006e+01 1.4200000e+02]\n",
" ...\n",
" [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n",
" 7.8405350e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n",
" 5.1310150e+01 3.5000000e+01]\n",
" [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n",
" 6.3619797e+01 4.4000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,400))\n",
"wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :] #[B, C, T]\n",
"\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"print(spec)\n",
"\n",
"x = spec[0].T.numpy()\n",
"y = rspec[:, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "olive-nicaragua",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
},
{
"data": {
"text/plain": [
"27241"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(np.abs(x -y) / np.abs(y))"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "ultimate-assault",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "institutional-stock",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4.2412265e-10"
]
},
"execution_count": 166,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "integrated-courage",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 167,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(y, x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "different-operation",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
......@@ -16,7 +16,7 @@
## Setup
* python>=3.7
* paddlepaddle>=2.1.0
* paddlepaddle>=2.1.2
Please see [install](doc/src/install.md).
......
......@@ -17,7 +17,7 @@
## Installation
* python>=3.7
* paddlepaddle>=2.1.0
* paddlepaddle>=2.1.2
See [Installation](doc/src/install.md)
......
......@@ -18,8 +18,10 @@ import numpy as np
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
......@@ -78,26 +80,31 @@ def inference(config, args):
def start_server(config, args):
"""Start the ASR server"""
config.defrost()
config.data.manfiest = config.data.test_manifest
config.data.augmentation_config = ""
config.data.keep_transcription_text = True
config.data.manifest = config.data.test_manifest
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
config.collator.augmentation_config = ""
config.collator.keep_transcription_text = True
config.collator.batch_size = 1
config.collator.num_workers = 0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path)
model.eval()
# prepare ASR inference handler
def file_to_transcript(filename):
feature = dataset.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
audio_len = feature[0].shape[1]
feature = test_loader.collate_fn.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, T, D]
audio_len = feature[0].shape[0]
audio_len = np.array([audio_len]).astype('int64') # [1]
result_transcript = model.decode(
paddle.to_tensor(audio),
paddle.to_tensor(audio_len),
vocab_list=dataset.vocab_list,
vocab_list=test_loader.collate_fn.vocab_list,
decoding_method=config.decoding.decoding_method,
lang_model_path=config.decoding.lang_model_path,
beam_alpha=config.decoding.alpha,
......@@ -138,7 +145,7 @@ if __name__ == "__main__":
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('host_port', int, 8089, "Server's IP port.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
......
......@@ -16,8 +16,10 @@ import functools
import numpy as np
import paddle
from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
......@@ -31,26 +33,35 @@ from deepspeech.utils.utility import print_arguments
def start_server(config, args):
"""Start the ASR server"""
config.defrost()
config.data.manfiest = config.data.test_manifest
config.data.augmentation_config = ""
config.data.keep_transcription_text = True
config.data.manifest = config.data.test_manifest
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
config.collator.augmentation_config = ""
config.collator.keep_transcription_text = True
config.collator.batch_size = 1
config.collator.num_workers = 0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path)
model.eval()
# prepare ASR inference handler
def file_to_transcript(filename):
feature = dataset.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
audio_len = feature[0].shape[1]
feature = test_loader.collate_fn.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, T, D]
# audio = audio.swapaxes(1,2)
print('---file_to_transcript feature----')
print(audio.shape)
audio_len = feature[0].shape[0]
print(audio_len)
audio_len = np.array([audio_len]).astype('int64') # [1]
result_transcript = model.decode(
paddle.to_tensor(audio),
paddle.to_tensor(audio_len),
vocab_list=dataset.vocab_list,
vocab_list=test_loader.collate_fn.vocab_list,
decoding_method=config.decoding.decoding_method,
lang_model_path=config.decoding.lang_model_path,
beam_alpha=config.decoding.alpha,
......@@ -91,7 +102,7 @@ if __name__ == "__main__":
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('host_port', int, 8088, "Server's IP port.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
......
......@@ -30,11 +30,15 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
......
......@@ -30,11 +30,15 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html
config = get_cfg_defaults()
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
......
......@@ -35,11 +35,15 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = get_cfg_defaults()
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
......
......@@ -47,7 +47,7 @@ def tune(config, args):
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
model = DeepSpeech2Model.from_pretrained(valid_loader, config,
args.checkpoint_path)
model.eval()
......
......@@ -11,77 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from yacs.config import CfgNode as CN
from deepspeech.models.deepspeech2 import DeepSpeech2Model
_C = CN()
_C.data = CN(
dict(
train_manifest="",
dev_manifest="",
test_manifest="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
mean_std_filepath="",
augmentation_config="",
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delat_delta=False, # 'mfcc', 'fbank'
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
))
_C.model = CN(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=3, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
DeepSpeech2Model.params(_C.model)
_C.training = CN(
dict(
lr=5e-4, # learning rate
lr_decay=1.0, # learning rate decay
weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs
))
_C.decoding = CN(
dict(
alpha=2.5, # Coef of LM for beam search.
beta=0.3, # Coef of WC for beam search.
cutoff_prob=1.0, # Cutoff probability for pruning.
cutoff_top_n=40, # Cutoff number for pruning.
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy
error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer'
num_proc_bsearch=8, # # of CPUs for beam search.
beam_size=500, # Beam search width.
batch_size=128, # decoding batch size
))
def get_cfg_defaults():
from yacs.config import CfgNode
from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester
from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
def get_cfg_defaults(model_type='offline'):
_C = CfgNode()
_C.data = ManifestDataset.params()
_C.collator = SpeechCollator.params()
_C.training = DeepSpeech2Trainer.params()
_C.decoding = DeepSpeech2Tester.params()
if model_type == 'offline':
_C.model = DeepSpeech2Model.params()
else:
_C.model = DeepSpeech2ModelOnline.params()
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
......
......@@ -11,39 +11,61 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 model."""
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.deepspeech2 import DeepSpeech2InferModel
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class DeepSpeech2Trainer(Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# training config
default = CfgNode(
dict(
lr=5e-4, # learning rate
lr_decay=1.0, # learning rate decay
weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
start = time.time()
loss = self.model(*batch_data)
utt, audio, audio_len, text, text_len = batch_data
loss = self.model(audio, audio_len, text, text_len)
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
self.optimizer.step()
......@@ -54,7 +76,7 @@ class DeepSpeech2Trainer(Trainer):
'train_loss': float(loss),
}
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
......@@ -73,9 +95,10 @@ class DeepSpeech2Trainer(Trainer):
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
loss = self.model(*batch)
utt, audio, audio_len, text, text_len = batch
loss = self.model(audio, audio_len, text, text_len)
if paddle.isfinite(loss):
num_utts = batch[0].shape[0]
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
......@@ -98,16 +121,18 @@ class DeepSpeech2Trainer(Trainer):
return total_loss, num_seen_utts
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.train_loader.dataset.feature_size,
dict_size=self.train_loader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
config = self.config.clone()
config.defrost()
config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.freeze()
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)
elif self.args.model_type == 'online':
model = DeepSpeech2ModelOnline.from_config(config.model)
else:
raise Exception("wrong model type")
if self.parallel:
model = paddle.DataParallel(model)
......@@ -135,50 +160,87 @@ class DeepSpeech2Trainer(Trainer):
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.data.keep_transcription_text = False
config.collator.keep_transcription_text = False
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.test_manifest
test_dataset = ManifestDataset.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
collate_fn_test = SpeechCollator.from_config(config)
collate_fn = SpeechCollator(keep_transcription_text=False)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=config.data.num_workers)
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers)
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
logger.info("Setup train/valid Dataloader!")
collate_fn=collate_fn_dev)
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_test)
logger.info("Setup train/valid/test Dataloader!")
class DeepSpeech2Tester(DeepSpeech2Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# testing config
default = CfgNode(
dict(
alpha=2.5, # Coef of LM for beam search.
beta=0.3, # Coef of WC for beam search.
cutoff_prob=1.0, # Cutoff probability for pruning.
cutoff_top_n=40, # Cutoff number for pruning.
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy
error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer'
num_proc_bsearch=8, # # of CPUs for beam search.
beam_size=500, # Beam search width.
batch_size=128, # decoding batch size
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
super().__init__(config, args)
......@@ -191,15 +253,23 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, audio, audio_len, texts, texts_len):
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.dataset.vocab_list
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len)
self.autolog.times.start()
self.autolog.times.stamp()
result_transcripts = self.model.decode(
audio,
audio_len,
......@@ -212,12 +282,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
self.autolog.times.stamp()
self.autolog.times.stamp()
self.autolog.times.end()
for target, result in zip(target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("Current error rate [%s] = %f" %
......@@ -234,19 +310,25 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
@paddle.no_grad()
def test(self):
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
self.autolog = Autolog(
batch_size=self.config.decoding.batch_size,
model_name="deepspeech2",
model_precision="fp32").getlog()
self.model.eval()
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts,
texts_len, fout)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
......@@ -255,6 +337,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
self.autolog.report()
def run_test(self):
self.resume_or_scratch()
......@@ -264,19 +347,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
exit(-1)
def export(self):
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader.dataset, self.config, self.args.checkpoint_path)
if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
elif self.args.model_type == 'online':
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
raise Exception("wrong model type")
infer_model.eval()
feat_dim = self.test_loader.dataset.feature_size
static_model = paddle.jit.to_static(
infer_model,
input_spec=[
paddle.static.InputSpec(
shape=[None, None, feat_dim],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
feat_dim = self.test_loader.collate_fn.feature_size
static_model = infer_model.export()
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
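# A minimal sketch of consuming the exported graph, assuming the export
# path saved above; the loaded layer takes the same audio [B, T, D] and
# audio_length [B] inputs as the infer model:
#
#     import paddle
#     layer = paddle.jit.load(self.args.export_path)
#     layer.eval()
#     outputs = layer(audio, audio_len)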
......@@ -300,46 +382,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
self.iteration = 0
self.epoch = 0
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.test_loader.dataset.feature_size,
dict_size=self.test_loader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model
logger.info("Setup model!")
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
# return raw text
config.data.manifest = config.data.test_manifest
config.data.keep_transcription_text = True
config.data.augmentation_config = ""
# filtering test examples yields fewer utterances but keeps them consistent with training,
# and allows a large batch size to save time, so filter test egs now.
# config.data.min_input_len = 0.0 # second
# config.data.max_input_len = float('inf') # second
# config.data.min_output_len = 0.0 # tokens
# config.data.max_output_len = float('inf') # tokens
# config.data.min_output_input_ratio = 0.00
# config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
logger.info("Setup test Dataloader!")
def setup_output_dir(self):
"""Create a directory used for output.
"""
......
......@@ -15,6 +15,7 @@ from yacs.config import CfgNode
from deepspeech.exps.u2.model import U2Tester
from deepspeech.exps.u2.model import U2Trainer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.u2 import U2Model
......@@ -22,6 +23,8 @@ _C = CfgNode()
_C.data = ManifestDataset.params()
_C.collator = SpeechCollator.params()
_C.model = U2Model.params()
_C.training = U2Trainer.params()
......
......@@ -31,12 +31,15 @@ from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Model
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
......@@ -76,8 +79,10 @@ class U2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(*batch_data)
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
......@@ -99,7 +104,7 @@ class U2Trainer(Trainer):
if (batch_index + 1) % train_conf.log_interval == 0:
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
......@@ -119,9 +124,11 @@ class U2Trainer(Trainer):
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
loss, attention_loss, ctc_loss = self.model(*batch)
utt, audio, audio_len, text, text_len = batch
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
if paddle.isfinite(loss):
num_utts = batch[0].shape[0]
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
......@@ -209,51 +216,52 @@ class U2Trainer(Trainer):
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.data.keep_transcription_text = False
config.collator.keep_transcription_text = False
# train/valid dataset, return token ids
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config)
collate_fn = SpeechCollator(keep_transcription_text=False)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=config.data.num_workers, )
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers, )
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
collate_fn=collate_fn_dev)
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
config.data.keep_transcription_text = True
config.data.augmentation_config = ""
# filtering test examples yields fewer utterances but keeps them consistent with training,
# and allows a large batch size to save time, so filter test egs now.
# config.data.min_input_len = 0.0 # second
......@@ -262,22 +270,33 @@ class U2Trainer(Trainer):
# config.data.max_output_len = float('inf') # tokens
# config.data.min_output_input_ratio = 0.00
# config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
logger.info("Setup train/valid/test Dataloader!")
collate_fn=SpeechCollator.from_config(config))
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.dataset.feature_size
model_conf.output_dim = self.train_loader.dataset.vocab_size
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model_conf.freeze()
model = U2Model.from_config(model_conf)
......@@ -293,30 +312,38 @@ class U2Trainer(Trainer):
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
if scheduler_type == 'expdecaylr':
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=optim_conf.lr,
gamma=scheduler_conf.lr_decay,
verbose=False)
elif scheduler_type == 'warmuplr':
lr_scheduler = WarmupLR(
learning_rate=optim_conf.lr,
warmup_steps=scheduler_conf.warmup_steps,
verbose=False)
else:
raise ValueError(f"Not support scheduler: {scheduler_type}")
if optim_type == 'adam':
optimizer = paddle.optimizer.Adam(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=weight_decay,
grad_clip=grad_clip)
else:
raise ValueError(f"Not support optim: {optim_type}")
scheduler_args = {
"learning_rate": optim_conf.lr,
"verbose": False,
"warmup_steps": scheduler_conf.warmup_steps,
"gamma": scheduler_conf.lr_decay,
"d_model": model_conf.encoder_conf.output_size,
}
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
scheduler_args)
def optimizer_args(
config,
parameters,
lr_scheduler=None, ):
train_config = config.training
optim_type = train_config.optim
optim_conf = train_config.optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"weight_decay": optim_conf.weight_decay,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
"parameters": parameters,
"epsilon": 1e-9 if optim_type == 'noam' else None,
"beta1": 0.9 if optim_type == 'noam' else None,
"beat2": 0.98 if optim_type == 'noam' else None,
}
optim_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optim_args)
self.model = model
self.optimizer = optimizer
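# The 'noam' branch above mirrors the usual Transformer recipe (Adam with
# beta1=0.9, beta2=0.98, epsilon=1e-9 under a warmup schedule). A sketch of
# the resulting factory calls, assuming optim='noam' and scheduler='noam'
# are accepted types (as the d_model entry in scheduler_args suggests):
#
#     lr_scheduler = LRSchedulerFactory.from_args('noam', scheduler_args)
#     optimizer = OptimizerFactory.from_args(
#         'noam', optimizer_args(config, model.parameters(), lr_scheduler))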
......@@ -345,7 +372,7 @@ class U2Tester(U2Trainer):
decoding_chunk_size=-1, # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1.
simulate_streaming=False, # simulate streaming inference. Defaults to False.
))
......@@ -366,14 +393,20 @@ class U2Tester(U2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, audio, audio_len, texts, texts_len, fout=None):
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = self.test_loader.dataset.text_feature
text_feature = self.test_loader.collate_fn.text_feature
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode(
audio,
......@@ -393,13 +426,14 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for target, result in zip(target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write(result + "\n")
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("One example error rate [%s] = %f" %
......@@ -421,7 +455,7 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.dataset.stride_ms
stride_ms = self.test_loader.collate_fn.stride_ms
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
......@@ -483,6 +517,73 @@ class U2Tester(U2Trainer):
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must run with batch_size == 1')
sys.exit(1)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment)
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs)
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=textgrid_path)
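# Worked example of the timing above: with the default stride_ms=10.0,
# second_per_frame = 1 / (1000 / 10) = 0.01 s per frame; assuming a
# subsample factor of 4 and a 100-frame alignment, the TextGrid spans
# (100 + 1) * 4 * 0.01 = 4.04 s.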
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
......@@ -491,15 +592,14 @@ class U2Tester(U2Trainer):
List[paddle.static.InputSpec]: input spec.
"""
from deepspeech.models.u2 import U2InferModel
infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.model.clone(),
self.args.checkpoint_path)
feat_dim = self.test_loader.dataset.feature_size
feat_dim = self.test_loader.collate_fn.feature_size
input_spec = [
paddle.static.InputSpec(
shape=[None, feat_dim, None],
dtype='float32'), # audio, [B,D,T]
paddle.static.InputSpec(shape=[None],
paddle.static.InputSpec(shape=[1, None, feat_dim],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[1],
dtype='int64'), # audio_length, [B]
]
return infer_model, input_spec
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from deepspeech.exps.u2_st.config import get_cfg_defaults
from deepspeech.exps.u2_st.model import U2STTester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_export()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile
from deepspeech.exps.u2_st.config import get_cfg_defaults
from deepspeech.exps.u2_st.model import U2STTester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats('test.profile')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for U2 model."""
import cProfile
import os
from paddle import distributed as dist
from deepspeech.exps.u2_st.config import get_cfg_defaults
from deepspeech.exps.u2_st.model import U2STTrainer as Trainer
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
exp = Trainer(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.device == "gpu" and args.nprocs > 1:
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats(os.path.join(args.output, 'train.profile'))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from yacs.config import CfgNode
from deepspeech.exps.u2_st.model import U2STTester
from deepspeech.exps.u2_st.model import U2STTrainer
from deepspeech.io.collator_st import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.u2_st import U2STModel
_C = CfgNode()
_C.data = ManifestDataset.params()
_C.collator = SpeechCollator.params()
_C.model = U2STModel.params()
_C.training = U2STTrainer.params()
_C.decoding = U2STTester.params()
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
config = _C.clone()
config.set_new_allowed(True)
return config
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains U2 model."""
import json
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.io.collator_st import KaldiPrePorocessedCollator
from deepspeech.io.collator_st import SpeechCollator
from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
from deepspeech.io.collator_st import TripletSpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.dataset import TripletManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2_st import U2STModel
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.trainer import Trainer
from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class U2STTrainer(Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# training config
default = CfgNode(
dict(
n_epoch=50, # train epochs
log_interval=100, # steps
accum_grad=1, # accum grad by # steps
global_grad_clip=5.0, # the global norm clip
))
default.optim = 'adam'
default.optim_conf = CfgNode(
dict(
lr=5e-4, # learning rate
weight_decay=1e-6, # the coeff of weight decay
))
default.scheduler = 'warmuplr'
default.scheduler_conf = CfgNode(
dict(
warmup_steps=25000,
lr_decay=1.0, # learning rate decay
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
if isinstance(text, list) and isinstance(text_len, list):
# joint training with ASR. Two decoding texts [translation, transcription]
text, text_transcript = text
text_len, text_transcript_len = text_len
loss, st_loss, attention_loss, ctc_loss = self.model(
audio, audio_len, text, text_len, text_transcript,
text_transcript_len)
else:
loss, st_loss, attention_loss, ctc_loss = self.model(
audio, audio_len, text, text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad}
losses_np['st_loss'] = float(st_loss)
if attention_loss:
losses_np['att_loss'] = float(attention_loss)
if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss)
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
self.iteration += 1
iteration_time = time.time() - start
if (batch_index + 1) % train_conf.log_interval == 0:
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
self.visualizer.add_scalars("step", losses_np_v,
self.iteration - 1)
@paddle.no_grad()
def valid(self):
self.model.eval()
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
if isinstance(text, list) and isinstance(text_len, list):
text, text_transcript = text
text_len, text_transcript_len = text_len
loss, st_loss, attention_loss, ctc_loss = self.model(
audio, audio_len, text, text_len, text_transcript,
text_transcript_len)
else:
loss, st_loss, attention_loss, ctc_loss = self.model(
audio, audio_len, text, text_len)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(st_loss) * num_utts
valid_losses['val_loss'].append(float(st_loss))
if attention_loss:
valid_losses['val_att_loss'].append(float(attention_loss))
if ctc_loss:
valid_losses['val_ctc_loss'].append(float(ctc_loss))
if (i + 1) % self.config.training.log_interval == 0:
valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
valid_dump['val_history_st_loss'] = total_loss / num_seen_utts
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
logger.info('Rank {} Val info st_val_loss {}'.format(
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
def train(self):
"""The training process control by step."""
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
# script_model = paddle.jit.to_static(self.model)
# script_model_path = str(self.checkpoint_dir / 'init')
# paddle.jit.save(script_model, script_model_path)
from_scratch = self.resume_or_scratch()
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
self.model.train()
try:
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
# the default operator in all_reduce function is sum.
dist.all_reduce(num_seen_utts)
total_loss = paddle.to_tensor(total_loss)
dist.all_reduce(total_loss)
cv_loss = total_loss / num_seen_utts
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalars(
'epoch', {'cv_loss': cv_loss,
'lr': self.lr_scheduler()}, self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.collator.keep_transcription_text = False
# train/valid dataset, return token ids
Dataset = TripletManifestDataset if config.model.model_conf.asr_weight > 0. else ManifestDataset
config.data.manifest = config.data.train_manifest
train_dataset = Dataset.from_config(config)
config.data.manifest = config.data.dev_manifest
dev_dataset = Dataset.from_config(config)
if config.collator.raw_wav:
if config.model.model_conf.asr_weight > 0.:
Collator = TripletSpeechCollator
TestCollator = SpeechCollator
else:
TestCollator = Collator = SpeechCollator
# The MTL loader for raw_wav is not implemented yet.
else:
if config.model.model_conf.asr_weight > 0.:
Collator = TripletKaldiPrePorocessedCollator
TestCollator = KaldiPrePorocessedCollator
else:
TestCollator = Collator = KaldiPrePorocessedCollator
collate_fn_train = Collator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = Collator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers, )
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
# filtering test examples yields fewer utterances but keeps them consistent with training,
# and allows a large batch size to save time, so filter test egs now.
# config.data.min_input_len = 0.0 # second
# config.data.max_input_len = float('inf') # second
# config.data.min_output_len = 0.0 # tokens
# config.data.max_output_len = float('inf') # tokens
# config.data.min_output_input_ratio = 0.00
# config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=TestCollator.from_config(config))
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=TestCollator.from_config(config))
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model_conf.freeze()
model = U2STModel.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
train_config = config.training
optim_type = train_config.optim
optim_conf = train_config.optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
if scheduler_type == 'expdecaylr':
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=optim_conf.lr,
gamma=scheduler_conf.lr_decay,
verbose=False)
elif scheduler_type == 'warmuplr':
lr_scheduler = WarmupLR(
learning_rate=optim_conf.lr,
warmup_steps=scheduler_conf.warmup_steps,
verbose=False)
elif scheduler_type == 'noam':
lr_scheduler = paddle.optimizer.lr.NoamDecay(
learning_rate=optim_conf.lr,
d_model=model_conf.encoder_conf.output_size,
warmup_steps=scheduler_conf.warmup_steps,
verbose=False)
else:
raise ValueError(f"Not support scheduler: {scheduler_type}")
grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
if optim_type == 'adam':
optimizer = paddle.optimizer.Adam(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=weight_decay,
grad_clip=grad_clip)
else:
raise ValueError(f"Not support optim: {optim_type}")
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!")
class U2STTester(U2STTrainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# decoding config
default = CfgNode(
dict(
alpha=2.5, # Coef of LM for beam search.
beta=0.3, # Coef of WC for beam search.
cutoff_prob=1.0, # Cutoff probability for pruning.
cutoff_top_n=40, # Cutoff number for pruning.
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='attention', # Decoding method. Options: 'attention', 'ctc_greedy_search',
# 'ctc_prefix_beam_search', 'attention_rescoring'
error_rate_type='bleu', # Error rate type for evaluation. Options: 'bleu', 'char-bleu'
num_proc_bsearch=8, # Number of CPU processes for beam search.
beam_size=10, # Beam search width.
batch_size=16, # decoding batch size
ctc_weight=0.0, # ctc weight for attention rescoring decode mode.
decoding_chunk_size=-1, # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1.
simulate_streaming=False, # simulate streaming inference. Defaults to False.
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
super().__init__(config, args)
def ordid2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_translation_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
bleu_func,
fout=None):
cfg = self.config.decoding
len_refs, num_ins = 0, 0
start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature
refs = [
"".join(chr(t) for t in text[:text_len])
for text, text_len in zip(texts, texts_len)
]
hyps = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for utt, target, result in zip(utts, refs, hyps):
len_refs += len(target.split())
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nReference: %s\nHypothesis: %s" % (target, result))
logger.info("One example BLEU = %s" %
(bleu_func([result], [[target]]).prec_str))
return dict(
hyps=hyps,
refs=refs,
bleu=bleu_func(hyps, [refs]).score,
len_refs=len_refs,
num_ins=num_ins, # num examples
num_frames=audio_len.sum().numpy().item(),
decode_time=decode_time)
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
assert self.args.result_file
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
cfg = self.config.decoding
bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
stride_ms = self.test_loader.collate_fn.stride_ms
hyps, refs = [], []
len_refs, num_ins = 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_translation_metrics(
*batch, bleu_func=bleu_func, fout=fout)
hyps += metrics['hyps']
refs += metrics['refs']
bleu = metrics['bleu']
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
rtf = num_time / (num_frames * stride_ms)
logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu))
rtf = num_time / (num_frames * stride_ms)
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "RTF: {}, ".format(rtf)
msg += "Test set [%s]: %s" % (len(hyps), str(bleu_func(hyps, [refs])))
logger.info(msg)
bleu_meta_path = os.path.splitext(self.args.result_file)[0] + '.bleu'
err_type_str = "BLEU"
with open(bleu_meta_path, 'w') as f:
data = json.dumps({
"epoch":
self.epoch,
"step":
self.iteration,
"rtf":
rtf,
err_type_str:
bleu_func(hyps, [refs]).score,
"dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
"process_hour":
num_time / 1000.0 / 3600.0,
"num_examples":
num_ins,
"decode_method":
self.config.decoding.decoding_method,
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must run with batch_size == 1')
sys.exit(1)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment)
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs)
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=textgrid_path)
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
Returns:
nn.Layer: inference model
List[paddle.static.InputSpec]: input spec.
"""
from deepspeech.models.u2 import U2InferModel
infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.model.clone(),
self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size
input_spec = [
paddle.static.InputSpec(shape=[1, None, feat_dim],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[1],
dtype='int64'), # audio_length, [B]
]
return infer_model, input_spec
def export(self):
infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec)
infer_model.eval()
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device(self.args.device)
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
......@@ -107,7 +107,6 @@ class SpeechFeaturizer(object):
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
......@@ -116,7 +115,6 @@ class SpeechFeaturizer(object):
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
......@@ -125,7 +123,6 @@ class SpeechFeaturizer(object):
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
......@@ -134,7 +131,6 @@ class SpeechFeaturizer(object):
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
......@@ -143,7 +139,6 @@ class SpeechFeaturizer(object):
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
......@@ -152,7 +147,6 @@ class SpeechFeaturizer(object):
@property
def text_feature(self):
"""Return the text feature object.
Returns:
TextFeaturizer: object.
"""
......
......@@ -11,8 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import tarfile  # used by SpeechCollator._parse_tar below
from collections import namedtuple
from typing import Optional
import numpy as np
from yacs.config import CfgNode
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log
......@@ -21,17 +30,221 @@ __all__ = ["SpeechCollator"]
logger = Log(__name__).getlog()
# namedtuple needs to be defined at module level (global) for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class SpeechCollator():
def __init__(self, keep_transcription_text=True):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.collator
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
specgram_type=config.collator.specgram_type,
feat_dim=config.collator.feat_dim,
delta_delta=config.collator.delta_delta,
stride_ms=config.collator.stride_ms,
window_ms=config.collator.window_ms,
n_fft=config.collator.n_fft,
max_freq=config.collator.max_freq,
target_sample_rate=config.collator.target_sample_rate,
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
def __init__(
self,
aug_file,
mean_std_filepath,
vocab_filepath,
spm_model_prefix,
random_seed=0,
unit_type="char",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0,
keep_transcription_text=True):
"""SpeechCollator Collator
Args:
unit_type(str): token unit type, e.g. char, word, spm
vocab_filepath (str): vocab file path.
mean_std_filepath (str): mean and std file path, which suffix is *.npy
spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
target_dB (int, optional): target dB. Defaults to -20.
random_seed (int, optional): for random generator. Defaults to 0.
keep_transcription_text (bool, optional): if True (typically outside training), skip tokenization and keep the raw transcription text. Defaults to False.
If ``keep_transcription_text`` is False, text is token ids; otherwise it is the raw string.
Applies augmentations, then pads audio features with zeros so they share the same shape (or
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
and a map from member names to tarinfo objects.
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def _subfile_from_tar(self, file):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if 'tar2info' not in self._local_data.__dict__:
self._local_data.tar2info = {}
if 'tar2object' not in self._local_data.__dict__:
self._local_data.tar2object = {}
if tarpath not in self._local_data.tar2info:
object, infoes = self._parse_tar(tarpath)
self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2object[tarpath] = object
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
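# Example of the composite path this expects (the manifest entry below is
# hypothetical):
#     audio_file = "tar:/data/wav_1.tar#utt0001.wav"
#     tarpath, filename = audio_file.split(':', 1)[1].split('#', 1)
#     # tarpath  -> "/data/wav_1.tar"
#     # filename -> "utt0001.wav"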
def process_utterance(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param transcript: Transcription text.
:type transcript: str
:return: Tuple of audio feature tensor and data of transcription part,
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
# audio augment
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])  # [D, T] -> [T, D], time-major for padding
return specgram, transcript_part
def __call__(self, batch):
"""batch examples
......@@ -51,10 +264,14 @@ class SpeechCollator():
audio_lens = []
texts = []
text_lens = []
for audio, text in batch:
utts = []
for utt, audio, text in batch:
audio, text = self.process_utterance(audio, text)
#utt
utts.append(utt)
# audio
audios.append(audio.T) # [T, D]
audio_lens.append(audio.shape[1])
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
......@@ -75,4 +292,32 @@ class SpeechCollator():
padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return padded_audios, audio_lens, padded_texts, text_lens
return utts, padded_audios, audio_lens, padded_texts, text_lens
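# The batch therefore unpacks with the utterance keys first, matching the
# trainers above, e.g.:
#     utts, audio, audio_len, text, text_len = collate_fn(minibatch)
# where minibatch is a list of (utt, audio, text) examples from ManifestDataset.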
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from collections import namedtuple
from typing import Optional
import kaldiio
import numpy as np
from yacs.config import CfgNode
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log
__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"]
logger = Log(__name__).getlog()
# namedtuple needs to be global for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class SpeechCollator():
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.collator
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
specgram_type=config.collator.specgram_type,
feat_dim=config.collator.feat_dim,
delta_delta=config.collator.delta_delta,
stride_ms=config.collator.stride_ms,
window_ms=config.collator.window_ms,
n_fft=config.collator.n_fft,
max_freq=config.collator.max_freq,
target_sample_rate=config.collator.target_sample_rate,
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
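    # Usage sketch (an assumption for illustration, not part of the original
    # file): the collator is intended to be handed to a paddle DataLoader as
    # `collate_fn`, e.g.
    #
    #   collator = SpeechCollator.from_config(config)
    #   loader = paddle.io.DataLoader(dataset, batch_size=32, collate_fn=collator)
    #
    # which is also why downstream code can reach featurizer properties through
    # `dataloader.collate_fn.feature_size` and `dataloader.collate_fn.vocab_size`.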
def __init__(
self,
aug_file,
mean_std_filepath,
vocab_filepath,
spm_model_prefix,
random_seed=0,
unit_type="char",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0,
keep_transcription_text=True):
"""SpeechCollator Collator
Args:
unit_type(str): token unit type, e.g. char, word, spm
vocab_filepath (str): vocab file path.
            mean_std_filepath (str): mean and std file path, whose suffix is *.npy
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
target_dB (int, optional): target dB. Defaults to -20.
random_seed (int, optional): for random generator. Defaults to 0.
            keep_transcription_text (bool, optional): if True (typically outside training), the text is kept as a raw string and not tokenized; if False, the text is converted to token ids. Defaults to False.
Do augmentations
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def _subfile_from_tar(self, file):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if 'tar2info' not in self._local_data.__dict__:
self._local_data.tar2info = {}
if 'tar2object' not in self._local_data.__dict__:
self._local_data.tar2object = {}
if tarpath not in self._local_data.tar2info:
object, infoes = self._parse_tar(tarpath)
self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2object[tarpath] = object
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
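    # Note (illustrative, not from the original source): tarred utterances are
    # referenced as "tar:<tar_path>#<member_name>", e.g. the hypothetical
    # "tar:/data/train.tar#utt_0001.wav"; the split(':', 1) / split('#', 1)
    # calls above recover ("/data/train.tar", "utt_0001.wav"), and the member
    # is then extracted from the cached tarfile object.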
def process_utterance(self, audio_file, translation):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param translation: translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), translation)
else:
speech_segment = SpeechSegment.from_file(audio_file, translation)
# audio augment
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, translation_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part
def __call__(self, batch):
"""batch examples
Args:
            batch (List[tuple]): each item is (utt, audio, text)
audio (np.ndarray) shape (D, T)
text (List[int] or str): shape (U,)
Returns:
            tuple(utts, audio, audio_lens, text, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios = []
audio_lens = []
texts = []
text_lens = []
utts = []
for utt, audio, text in batch:
audio, text = self.process_utterance(audio, text)
#utt
utts.append(utt)
# audio
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = []
if self._keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens = [ord(t) for t in text]
else:
tokens = text # token ids
tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
tokens, dtype=np.int64)
texts.append(tokens)
text_lens.append(tokens.shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
audio_lens = np.array(audio_lens).astype(np.int64)
padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
class TripletSpeechCollator(SpeechCollator):
def process_utterance(self, audio_file, translation, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param translation: translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), translation)
else:
speech_segment = SpeechSegment.from_file(audio_file, translation)
# audio augment
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, translation_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
transcript_part = self._speech_featurizer._text_featurizer.featurize(
transcript)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part, transcript_part
def __call__(self, batch):
"""batch examples
Args:
            batch (List[tuple]): each item is (utt, audio, translation, transcription)
                audio (np.ndarray) shape (D, T)
                translation, transcription (List[int] or str): shape (U,)
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios = []
audio_lens = []
translation_text = []
translation_text_lens = []
transcription_text = []
transcription_text_lens = []
utts = []
for utt, audio, translation, transcription in batch:
audio, translation, transcription = self.process_utterance(
audio, translation, transcription)
#utt
utts.append(utt)
# audio
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = [[], []]
for idx, text in enumerate([translation, transcription]):
if self._keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens[idx] = [ord(t) for t in text]
else:
tokens[idx] = text # token ids
tokens[idx] = tokens[idx] if isinstance(
tokens[idx], np.ndarray) else np.array(
tokens[idx], dtype=np.int64)
translation_text.append(tokens[0])
translation_text_lens.append(tokens[0].shape[0])
transcription_text.append(tokens[1])
transcription_text_lens.append(tokens[1].shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
audio_lens = np.array(audio_lens).astype(np.int64)
padded_translation = pad_sequence(
translation_text, padding_value=IGNORE_ID).astype(np.int64)
translation_lens = np.array(translation_text_lens).astype(np.int64)
padded_transcription = pad_sequence(
transcription_text, padding_value=IGNORE_ID).astype(np.int64)
transcription_lens = np.array(transcription_text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, (
padded_translation, padded_transcription), (translation_lens,
transcription_lens)
class KaldiPrePorocessedCollator(SpeechCollator):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
feat_dim=0,
stride_ms=10.0,
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'vocab_filepath' in config.collator
assert config.collator
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
feat_dim=config.collator.feat_dim,
stride_ms=config.collator.stride_ms,
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
def __init__(self,
aug_file,
vocab_filepath,
spm_model_prefix,
random_seed=0,
unit_type="char",
feat_dim=0,
stride_ms=10.0,
keep_transcription_text=True):
"""SpeechCollator Collator
Args:
unit_type(str): token unit type, e.g. char, word, spm
vocab_filepath (str): vocab file path.
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
random_seed (int, optional): for random generator. Defaults to 0.
            keep_transcription_text (bool, optional): if True (typically outside training), the text is kept as a raw string and not tokenized; if False, the text is converted to token ids. Defaults to False.
Do augmentations
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
self._feat_dim = feat_dim
self._stride_ms = stride_ms
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed)
self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
spm_model_prefix)
def process_utterance(self, audio_file, translation):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of kaldi processed feature.
:type audio_file: str | file
:param translation: Translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0])
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text:
return specgram, translation
else:
text_ids = self._text_featurizer.featurize(translation)
return specgram, text_ids
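    # Note (an assumption for illustration): `audio_file` is expected to be a
    # kaldi-style specifier accepted by `kaldiio.load_mat`, e.g. the hypothetical
    # "feats.ark:42"; the loaded matrix is treated as [T, D], transposed to
    # [D, T] for feature augmentation, then transposed back before returning.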
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
return self._text_featurizer.vocab_list
@property
def vocab_dict(self):
return self._text_featurizer.vocab_dict
@property
def text_feature(self):
return self._text_featurizer
@property
def feature_size(self):
return self._feat_dim
@property
def stride_ms(self):
return self._stride_ms
class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
def process_utterance(self, audio_file, translation, transcript):
"""Load, augment, featurize and normalize for speech data.
        :param audio_file: Filepath or file object of kaldi processed feature.
:type audio_file: str | file
:param translation: Translation text.
:type translation: str
:param transcript: Transcription text.
:type transcript: str
:return: Tuple of audio feature tensor and data of translation and transcription parts,
where translation and transcription parts could be token ids or text.
:rtype: tuple of (2darray, (list, list))
"""
specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0])
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text:
return specgram, translation, transcript
else:
translation_text_ids = self._text_featurizer.featurize(translation)
transcript_text_ids = self._text_featurizer.featurize(transcript)
return specgram, translation_text_ids, transcript_text_ids
def __call__(self, batch):
"""batch examples
Args:
            batch (List[tuple]): each item is (utt, audio, translation, transcription)
audio (np.ndarray) shape (D, T)
translation (List[int] or str): shape (U,)
transcription (List[int] or str): shape (V,)
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
translation_text : (B, Umax)
translation_text_lens: (B)
transcription_text : (B, Vmax)
transcription_text_lens: (B)
"""
audios = []
audio_lens = []
translation_text = []
translation_text_lens = []
transcription_text = []
transcription_text_lens = []
utts = []
for utt, audio, translation, transcription in batch:
audio, translation, transcription = self.process_utterance(
audio, translation, transcription)
#utt
utts.append(utt)
# audio
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = [[], []]
for idx, text in enumerate([translation, transcription]):
if self._keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens[idx] = [ord(t) for t in text]
else:
tokens[idx] = text # token ids
tokens[idx] = tokens[idx] if isinstance(
tokens[idx], np.ndarray) else np.array(
tokens[idx], dtype=np.int64)
translation_text.append(tokens[0])
translation_text_lens.append(tokens[0].shape[0])
transcription_text.append(tokens[1])
transcription_text_lens.append(tokens[1].shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
audio_lens = np.array(audio_lens).astype(np.int64)
padded_translation = pad_sequence(
translation_text, padding_value=IGNORE_ID).astype(np.int64)
translation_lens = np.array(translation_text_lens).astype(np.int64)
padded_transcription = pad_sequence(
transcription_text, padding_value=IGNORE_ID).astype(np.int64)
transcription_lens = np.array(transcription_text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, (
padded_translation, padded_transcription), (translation_lens,
transcription_lens)
......@@ -11,72 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import tarfile
import time
from collections import namedtuple
from typing import Optional
import numpy as np
from paddle.io import Dataset
from yacs.config import CfgNode
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = ["ManifestDataset", "TripletManifestDataset"]
logger = Log(__name__).getlog()
# namedtuple needs to be global for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class ManifestDataset(Dataset):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
train_manifest="",
dev_manifest="",
test_manifest="",
manifest="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
mean_std_filepath="",
augmentation_config="",
max_input_len=27.0,
min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
raw_wav=True, # use raw_wav or kaldi feature
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
dither=1.0, # feature dither
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
))
min_output_input_ratio=0.0, ))
if config is not None:
config.merge_from_other_cfg(default)
......@@ -94,128 +53,38 @@ class ManifestDataset(Dataset):
"""
assert 'manifest' in config.data
assert config.data.manifest
assert 'keep_transcription_text' in config.data
if isinstance(config.data.augmentation_config, (str, bytes)):
if config.data.augmentation_config:
aug_file = io.open(
config.data.augmentation_config, mode='r', encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.data.augmentation_config
assert isinstance(aug_file, io.StringIO)
dataset = cls(
manifest_path=config.data.manifest,
unit_type=config.data.unit_type,
vocab_filepath=config.data.vocab_filepath,
mean_std_filepath=config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config=aug_file.read(),
max_input_len=config.data.max_input_len,
min_input_len=config.data.min_input_len,
max_output_len=config.data.max_output_len,
min_output_len=config.data.min_output_len,
max_output_input_ratio=config.data.max_output_input_ratio,
min_output_input_ratio=config.data.min_output_input_ratio,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delta_delta,
dither=config.data.dither,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=config.data.keep_transcription_text)
min_output_input_ratio=config.data.min_output_input_ratio, )
return dataset
def __init__(self,
manifest_path,
unit_type,
vocab_filepath,
mean_std_filepath,
spm_model_prefix=None,
augmentation_config='{}',
max_input_len=float('inf'),
min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
max_freq=None,
target_sample_rate=16000,
specgram_type='linear',
feat_dim=None,
delta_delta=False,
dither=1.0,
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False):
min_output_input_ratio=0.0):
"""Manifest Dataset
Args:
            manifest_path (str): manifest json file path
            unit_type(str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            mean_std_filepath (str): mean and std file path, whose suffix is *.npy
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            max_input_len (float, optional): maximum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
            min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
            max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to float('inf').
            min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0.
            max_output_input_ratio (float, optional): maximum output/input seq length ratio. Defaults to float('inf').
            min_output_input_ratio (float, optional): minimum output/input seq length ratio. Defaults to 0.0.
stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
target_dB (int, optional): target dB. Defaults to -20.
random_seed (int, optional): for random generator. Defaults to 0.
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
"""
super().__init__()
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=augmentation_config, random_seed=random_seed)
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
self._rng = np.random.RandomState(random_seed)
self._keep_transcription_text = keep_transcription_text
# for caching tar files info
self._local_data = TarLocalData(tar2info={}, tar2object={})
# read manifest
self._manifest = read_manifest(
......@@ -228,123 +97,22 @@ class ManifestDataset(Dataset):
min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0])
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def _subfile_from_tar(self, file):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if 'tar2info' not in self._local_data.__dict__:
self._local_data.tar2info = {}
if 'tar2object' not in self._local_data.__dict__:
self._local_data.tar2object = {}
if tarpath not in self._local_data.tar2info:
object, infoes = self._parse_tar(tarpath)
self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2object[tarpath] = object
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
def process_utterance(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param transcript: Transcription text.
:type transcript: str
:return: Tuple of audio feature tensor and data of transcription part,
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
start_time = time.time()
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
load_wav_time = time.time() - start_time
#logger.debug(f"load wav time: {load_wav_time}")
# audio augment
start_time = time.time()
self._augmentation_pipeline.transform_audio(speech_segment)
audio_aug_time = time.time() - start_time
#logger.debug(f"audio augmentation time: {audio_aug_time}")
start_time = time.time()
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
feature_time = time.time() - start_time
#logger.debug(f"audio & test feature time: {feature_time}")
# specgram augment
start_time = time.time()
specgram = self._augmentation_pipeline.transform_feature(specgram)
feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return specgram, transcript_part
def _instance_reader_creator(self, manifest):
"""
Instance reader creator. Create a callable function to produce
instances of data.
Instance: a tuple of ndarray of audio spectrogram and a list of
token indices for transcript.
"""
def __len__(self):
return len(self._manifest)
def reader():
for instance in manifest:
inst = self.process_utterance(instance["feat"],
instance["text"])
yield inst
def __getitem__(self, idx):
instance = self._manifest[idx]
return instance["utt"], instance["feat"], instance["text"]
return reader
def __len__(self):
return len(self._manifest)
class TripletManifestDataset(ManifestDataset):
"""
For Joint Training of Speech Translation and ASR.
text: translation,
text1: transcript.
"""
def __getitem__(self, idx):
instance = self._manifest[idx]
        return instance["utt"], instance["feat"], instance["text"], instance[
            "text1"]
......@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
from paddle.nn import functional as F
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['ConvStack', "conv_output_size"]
def conv_output_size(I, F, P, S):
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# Output size after Conv:
# By noting I the length of the input volume size,
# F the length of the filter,
# P the amount of zero padding,
# S the stride,
# then the output size O of the feature map along that dimension is given by:
# O = (I - F + Pstart + Pend) // S + 1
# When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
# When Pstart == Pend == 0
# O = (I - F - S) // S
# https://iq.opengenus.org/output-size-of-convolution/
# Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
# Output width = (Input width + padding width right + padding width left - kernel width) / (stride width) + 1
return (I - F + 2 * P - S) // S
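# Worked example (illustrative): conv_output_size(I=161, F=41, P=20, S=2)
# evaluates to (161 - 41 + 2 * 20 - 2) // 2 == 79, i.e. the helper computes
# O = (I - F + 2P - S) // S exactly as written above.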
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
class ConvBn(nn.Layer):
"""Convolution layer with batch normalization.
:param kernel_size: The x dimension of a filter kernel. Or input a tuple for
two image dimension.
:type kernel_size: int|tuple|list
:param num_channels_in: Number of input channels.
:type num_channels_in: int
:param num_channels_out: Number of output channels.
:type num_channels_out: int
:param stride: The x dimension of the stride. Or input a tuple for two
image dimension.
:type stride: int|tuple|list
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension.
:type padding: int|tuple|list
:param act: Activation type, relu|brelu
:type act: string
:return: Batch norm layer after convolution layer.
:rtype: Variable
"""
def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
padding, act):
super().__init__()
assert len(kernel_size) == 2
assert len(stride) == 2
assert len(padding) == 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.conv = nn.Conv2D(
num_channels_in,
num_channels_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
weight_attr=None,
bias_attr=False,
data_format='NCHW')
self.bn = nn.BatchNorm2D(
num_channels_out,
weight_attr=None,
bias_attr=None,
data_format='NCHW')
self.act = F.relu if act == 'relu' else brelu
def forward(self, x, x_len):
"""
x(Tensor): audio, shape [B, C, D, T]
"""
x = self.conv(x)
x = self.bn(x)
x = self.act(x)
x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
) // self.stride[1] + 1
# reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
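    # Worked example of the length update above (illustrative): with the first
    # ConvStack layer's kernel_size=(41, 11), stride=(2, 3), padding=(20, 5),
    # a time length T maps to (T - 11 + 2 * 5) // 3 + 1 == (T - 1) // 3 + 1
    # valid output frames before the padded positions are masked to zero.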
class ConvStack(nn.Layer):
"""Convolution group with stacked convolution layers.
:param feat_size: audio feature dim.
:type feat_size: int
:param num_stacks: Number of stacked convolution layers.
:type num_stacks: int
"""
def __init__(self, feat_size, num_stacks):
super().__init__()
self.feat_size = feat_size # D
self.num_stacks = num_stacks
self.conv_in = ConvBn(
num_channels_in=1,
num_channels_out=32,
kernel_size=(41, 11), #[D, T]
stride=(2, 3),
padding=(20, 5),
act='brelu')
out_channel = 32
convs = [
ConvBn(
num_channels_in=32,
num_channels_out=out_channel,
kernel_size=(21, 11),
stride=(2, 1),
padding=(10, 5),
act='brelu') for i in range(num_stacks - 1)
]
self.conv_stack = nn.LayerList(convs)
# conv output feat_dim
output_height = (feat_size - 1) // 2 + 1
for i in range(self.num_stacks - 1):
output_height = (output_height - 1) // 2 + 1
self.output_height = out_channel * output_height
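        # Illustrative shape check (not from the original source): with
        # feat_size=161 and num_stacks=2, output_height is (161 - 1) // 2 + 1 = 81
        # after conv_in and (81 - 1) // 2 + 1 = 41 after the second ConvBn,
        # so self.output_height = 32 * 41 = 1312.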
def forward(self, x, x_len):
"""
x: shape [B, C, D, T]
x_len : shape [B]
"""
x, x_len = self.conv_in(x, x_len)
for i, conv in enumerate(self.conv_stack):
x, x_len = conv(x, x_len)
return x, x_len
......@@ -18,16 +18,16 @@ import paddle
from paddle import nn
from yacs.config import CfgNode
from deepspeech.models.ds2.conv import ConvStack
from deepspeech.models.ds2.rnn import RNNStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
class CRNNEncoder(nn.Layer):
......@@ -117,7 +117,7 @@ class DeepSpeech2Model(nn.Layer):
:type share_weights: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
"""
@classmethod
......@@ -198,36 +198,57 @@ class DeepSpeech2Model(nn.Layer):
cutoff_top_n, num_processes)
@classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
        dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2Model
The model built from pretrained result.
"""
        model = cls(feat_size=dataloader.collate_fn.feature_size,
                    dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
        infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2Model from config
Parameters
config: yacs.config.CfgNode
config.model
Returns
-------
DeepSpeech2Model
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights)
return model
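    # Minimal usage sketch (an assumption, not part of the original file):
    #
    #   from yacs.config import CfgNode
    #   cfg = CfgNode(dict(feat_size=161, dict_size=4233, num_conv_layers=2,
    #                      num_rnn_layers=3, rnn_layer_size=1024, use_gru=False,
    #                      share_rnn_weights=True))  # hypothetical sizes
    #   model = DeepSpeech2Model.from_config(cfg)
    #
    # The keys mirror exactly those read in `from_config` above.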
class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self,
......@@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None, self.encoder.feat_size],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
return static_model
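    # Sketch of persisting the exported graph (an assumption, not from the
    # original source):
    #
    #   static_model = infer_model.export()
    #   paddle.jit.save(static_model, "exported/deepspeech2")  # hypothetical path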
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size: int,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
if activation not in ["tanh", "relu", "brelu"]:
raise ValueError(
"activation for SimpleRNNCell should be tanh or relu, "
"but get {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
    where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator.
"""
def __init__(self,
input_size: int,
hidden_size: int,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.tanh
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
r = self._gate_activation(x_r + h_r)
z = self._gate_activation(x_z + h_z)
c = self._activation(x_c + r * h_c) # apply reset gate after mm
h = (pre_hidden - c) * z + c
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int, share_weights: bool):
super().__init__()
self.share_weights = share_weights
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
#default:GRU using tanh
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
        self.rnn_stacks = nn.LayerList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
        x_len: shape [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4
class Conv2dSubsampling4Online(Conv2dSubsampling4):
def __init__(self, idim: int, odim: int, dropout_rate: float):
super().__init__(idim, odim, dropout_rate, None)
self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
self.receptive_field_length = 2 * (
            3 - 1) + 3  # stride_1 * (kernel_size_2 - 1) + kernel_size_1
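        # Illustrative numbers (not from the original source): for idim=161 and
        # odim=32, output_dim = ((161 - 1) // 2 - 1) // 2 * 32 = 39 * 32 = 1248,
        # and receptive_field_length = 2 * (3 - 1) + 3 = 7 input frames.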
def forward(self, x: paddle.Tensor,
x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
#b, c, t, f = paddle.shape(x) #not work under jit
x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
x_len = ((x_len - 1) // 2 - 1) // 2
return x, x_len
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
from typing import Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from yacs.config import CfgNode
from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.num_rnn_layers = num_rnn_layers
self.num_fc_layers = num_fc_layers
self.rnn_direction = rnn_direction
self.fc_layers_size_list = fc_layers_size_list
self.use_gru = use_gru
self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
self.output_dim = self.conv.output_dim
i_size = self.conv.output_dim
self.rnn = nn.LayerList()
self.layernorm_list = nn.LayerList()
self.fc_layers_list = nn.LayerList()
if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional':
layernorm_size = 2 * rnn_size
elif rnn_direction == 'forward':
layernorm_size = rnn_size
else:
raise Exception("Wrong rnn direction")
for i in range(0, num_rnn_layers):
if i == 0:
rnn_input_size = i_size
else:
rnn_input_size = layernorm_size
if use_gru == True:
self.rnn.append(
nn.GRU(
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
else:
self.rnn.append(
nn.LSTM(
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
self.output_dim = layernorm_size
fc_input_size = layernorm_size
for i in range(self.num_fc_layers):
self.fc_layers_list.append(
nn.Linear(fc_input_size, fc_layers_size_list[i]))
fc_input_size = fc_layers_size_list[i]
self.output_dim = fc_layers_size_list[i]
@property
def output_size(self):
return self.output_dim
def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
"""Compute Encoder outputs
Args:
x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
Returns:
x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
"""
if init_state_h_box is not None:
init_state_list = None
if self.use_gru == True:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list
else:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_c_list = paddle.split(
init_state_c_box, self.num_rnn_layers, axis=0)
init_state_list = [(init_state_h_list[i], init_state_c_list[i])
for i in range(self.num_rnn_layers)]
else:
init_state_list = [None] * self.num_rnn_layers
x, x_lens = self.conv(x, x_lens)
final_chunk_state_list = []
for i in range(0, self.num_rnn_layers):
x, final_state = self.rnn[i](x, init_state_list[i],
x_lens) #[B, T, D]
final_chunk_state_list.append(final_state)
x = self.layernorm_list[i](x)
for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x)
x = F.relu(x)
if self.use_gru == True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box)
else:
final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
]
final_chunk_state_c_list = [
final_chunk_state_list[i][1] for i in range(self.num_rnn_layers)
]
final_chunk_state_h_box = paddle.concat(
final_chunk_state_h_list, axis=0)
final_chunk_state_c_box = paddle.concat(
final_chunk_state_c_list, axis=0)
return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box
def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
"""Compute Encoder outputs
Args:
x (Tensor): [B, T, D]
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
"""
subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length
chunk_size = (decoder_chunk_size - 1
) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
max_len = x.shape[1]
assert (chunk_size <= max_len)
eouts_chunk_list = []
eouts_chunk_lens_list = []
padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
padded_x = paddle.concat([x, padding], axis=1)
num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
chunk_state_h_box = None
chunk_state_c_box = None
final_state_h_box = None
final_state_c_box = None
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
x_chunk = padded_x[:, start:end, :]
x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
paddle.zeros_like(x_lens),
x_lens - i * chunk_stride)
x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
x_len_left, x_chunk_len_tmp)
eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward(
x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)
eouts_chunk_list.append(eouts_chunk)
eouts_chunk_lens_list.append(eouts_chunk_lens)
final_state_h_box = chunk_state_h_box
final_state_c_box = chunk_state_c_box
return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box
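# A minimal sketch (not part of the source) of the chunk layout arithmetic
# used by forward_chunk_by_chunk above. The subsampling_rate and
# receptive_field_length values below are assumptions picked only for the
# example; the real values come from self.conv.
def _chunk_layout_example(max_len=300,
                          decoder_chunk_size=8,
                          subsampling_rate=4,
                          receptive_field_length=13):
    """Return (chunk_size, chunk_stride, padding_len, num_chunk)."""
    chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length
    chunk_stride = subsampling_rate * decoder_chunk_size
    padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
    num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
    return chunk_size, chunk_stride, padding_len, num_chunk
# e.g. _chunk_layout_example() -> (41, 32, 29, 10)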
class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
:type num_conv_layers: int
:param num_rnn_layers: Number of stacking RNN layers.
:type num_rnn_layers: int
:param rnn_size: RNN layer size (dimension of RNN cells).
:type rnn_size: int
:param num_fc_layers: Number of stacking FC layers.
:type num_fc_layers: int
:param fc_layers_size_list: The list of FC layer sizes.
:type fc_layers_size_list: [int,]
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=4, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
rnn_size=rnn_size,
use_gru=use_gru)
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
blank_id=0, # first token is <blank>
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Args:
audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
text (Tensor): [B, U]
text_len (Tensor): [B]
Returns:
loss (Tensor): [1]
"""
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
loss = self.decoder(eouts, eouts_len, text, text_len)
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2ModelOnline
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2ModelOnline from config
Parameters
config: yacs.config.CfgNode
config.model
Returns
-------
DeepSpeech2ModelOnline
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru)
return model
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box):
eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box)
probs_chunk = self.decoder.softmax(eouts_chunk)
return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None,
self.encoder.feat_size], #[B, chunk_size, feat_dim]
dtype='float32'),
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
paddle.static.InputSpec(
shape=[None, None, None], dtype='float32'),
paddle.static.InputSpec(
shape=[None, None, None], dtype='float32')
])
return static_model
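# Illustrative follow-up (the save path is an assumption for the example):
# the static graph returned by export() is typically serialized for
# inference with paddle.jit.save, e.g.
#   static_model = infer_model.export()
#   paddle.jit.save(static_model, "exp/deepspeech2_online/infer")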
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""U2 ASR Model
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import sys
......@@ -83,7 +83,7 @@ class U2BaseModel(nn.Module):
# cnn_module_kernel=15,
# activation_type='swish',
# pos_enc_layer_type='rel_pos',
# selfattention_layer_type='rel_selfattn',
))
# decoder related
default.decoder = 'transformer'
......@@ -244,8 +244,8 @@ class U2BaseModel(nn.Module):
simulate_streaming (bool, optional): streaming or not. Defaults to False.
Returns:
Tuple[paddle.Tensor, paddle.Tensor]:
encoder hiddens (B, Tmax, D),
encoder hiddens mask (B, 1, Tmax).
"""
# Let's assume B = batch_size
......@@ -399,6 +399,7 @@ class U2BaseModel(nn.Module):
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
batch_size = speech.shape[0]
# Let's assume B = batch_size
# encoder_out: (B, maxlen, encoder_dim)
# encoder_mask: (B, 1, Tmax)
......@@ -410,10 +411,12 @@ class U2BaseModel(nn.Module):
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps
......@@ -449,6 +452,7 @@ class U2BaseModel(nn.Module):
batch_size = speech.shape[0]
# For CTC prefix beam search, we only support batch_size=1
assert batch_size == 1
# Let's assume B = batch_size and N = beam_size
# 1. Encoder forward and get CTC score
encoder_out, encoder_mask = self._forward_encoder(
......@@ -458,7 +462,9 @@ class U2BaseModel(nn.Module):
maxlen = encoder_out.size(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
......@@ -498,6 +504,7 @@ class U2BaseModel(nn.Module):
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
return hyps, encoder_out
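# A sketch (assumed to match the imported log_add helper) of the log-domain
# addition used above: prefix scores are kept in ln space, so probabilities
# are summed with a numerically stable log-sum-exp.
import math

def _log_add_sketch(args):
    a_max = max(args)
    if a_max == -float('inf'):
        return -float('inf')
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))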
......@@ -561,12 +568,13 @@ class U2BaseModel(nn.Module):
batch_size = speech.shape[0]
# For attention rescoring we only support batch_size=1
assert batch_size == 1
# len(hyps) = beam_size, encoder_out: (1, maxlen, encoder_dim)
hyps, encoder_out = self._ctc_prefix_beam_search(
speech, speech_lengths, beam_size, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
assert len(hyps) == beam_size
hyps_pad = pad_sequence([
paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
for hyp in hyps
......@@ -576,49 +584,54 @@ class U2BaseModel(nn.Module):
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# the last decoder output token is `eos`, for the last decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos]
# add ctc score (which is in ln domain)
score += hyp[1] * ctc_weight
if score > best_score:
best_score = score
best_index = i
return hyps[best_index][0]
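    # Worked example (numbers made up for illustration): with ctc_weight=0.5,
    # a hypothesis with summed decoder log-prob -3.2 and CTC score -4.0 gets
    # -3.2 + 0.5 * -4.0 = -5.2; the hypothesis with the highest combined
    # score wins.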
#@jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
#@jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
#@jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
#@jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
......@@ -654,12 +667,14 @@ class U2BaseModel(nn.Module):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
# @jit.export([
#     paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (paddle.Tensor): encoder output, (B, T, D)
Returns:
paddle.Tensor: activation before ctc
"""
......@@ -717,8 +732,8 @@ class U2BaseModel(nn.Module):
feats (Tensor): audio features, (B, T, D)
feats_lengths (Tensor): (B)
text_feature (TextFeaturizer): text feature object.
decoding_method (str): decoding mode, e.g.
'attention', 'ctc_greedy_search',
'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path (str): lm path.
beam_alpha (float): lm weight.
......@@ -726,19 +741,19 @@ class U2BaseModel(nn.Module):
beam_size (int): beam size for search
cutoff_prob (float): for prune.
cutoff_top_n (int): for prune.
num_processes (int): number of parallel processes for decoding.
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here.
num_decoding_left_chunks (int, optional):
number of left chunks for decoding. Defaults to -1.
simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
Raises:
ValueError: when not support decoding_method.
Returns:
List[List[int]]: transcripts.
"""
......@@ -819,7 +834,7 @@ class U2Model(U2BaseModel):
ValueError: raise when using not support encoder type.
Returns:
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
"""
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
......@@ -876,11 +891,11 @@ class U2Model(U2BaseModel):
return model
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Args:
dataloader (paddle.io.DataLoader): not used.
config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
......@@ -888,13 +903,13 @@ class U2Model(U2BaseModel):
U2Model: The model built from pretrained result.
"""
config.defrost()
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
model = cls.from_config(config)
if checkpoint_path:
infos = checkpoint.Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""U2 ASR Model
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import time
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import paddle
from paddle import jit
from paddle import nn
from yacs.config import CfgNode
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import load_cmvn
from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.encoder import ConformerEncoder
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.loss import LabelSmoothingLoss
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import th_accuracy
__all__ = ["U2STModel", "U2STInferModel"]
logger = Log(__name__).getlog()
class U2STBaseModel(nn.Module):
"""CTC-Attention hybrid Encoder-Decoder model"""
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# network architecture
default = CfgNode()
# allow add new item when merge_with_file
default.cmvn_file = ""
default.cmvn_file_type = "json"
default.input_dim = 0
default.output_dim = 0
# encoder related
default.encoder = 'transformer'
default.encoder_conf = CfgNode(
dict(
output_size=256, # dimension of attention
attention_heads=4,
linear_units=2048, # the number of units of position-wise feed forward
num_blocks=12, # the number of encoder blocks
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer='conv2d', # encoder input type, you can choose conv2d, conv2d6 and conv2d8
normalize_before=True,
# use_cnn_module=True,
# cnn_module_kernel=15,
# activation_type='swish',
# pos_enc_layer_type='rel_pos',
# selfattention_layer_type='rel_selfattn',
))
# decoder related
default.decoder = 'transformer'
default.decoder_conf = CfgNode(
dict(
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
self_attention_dropout_rate=0.0,
src_attention_dropout_rate=0.0, ))
# hybrid CTC/attention
default.model_conf = CfgNode(
dict(
asr_weight=0.0,
ctc_weight=0.0,
lsm_weight=0.1, # label smoothing option
length_normalized_loss=False, ))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,
st_decoder: TransformerDecoder,
decoder: TransformerDecoder=None,
ctc: CTCDecoder=None,
ctc_weight: float=0.0,
asr_weight: float=0.0,
ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0,
length_normalized_loss: bool=False):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.asr_weight = asr_weight
self.encoder = encoder
self.st_decoder = st_decoder
self.decoder = decoder
self.ctc = ctc
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss, )
def forward(
self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
asr_text: paddle.Tensor=None,
asr_text_lengths: paddle.Tensor=None,
) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[
paddle.Tensor]]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
Returns:
total_loss, attention_loss, ctc_loss
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
text.shape, text_lengths.shape)
# 1. Encoder
start = time.time()
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. ST-decoder branch
start = time.time()
loss_st, acc_st = self._calc_st_loss(encoder_out, encoder_mask, text,
text_lengths)
decoder_time = time.time() - start
loss_asr_att = None
loss_asr_ctc = None
# 2b. ASR Attention-decoder branch
if self.asr_weight > 0.:
if self.ctc_weight != 1.0:
start = time.time()
loss_asr_att, acc_att = self._calc_att_loss(
encoder_out, encoder_mask, asr_text, asr_text_lengths)
decoder_time = time.time() - start
# 2c. CTC branch
if self.ctc_weight != 0.0:
start = time.time()
loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text,
asr_text_lengths)
ctc_time = time.time() - start
if loss_asr_ctc is None:
loss_asr = loss_asr_att
elif loss_asr_att is None:
loss_asr = loss_asr_ctc
else:
loss_asr = self.ctc_weight * loss_asr_ctc + (1 - self.ctc_weight
) * loss_asr_att
loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st
else:
loss = loss_st
return loss, loss_st, loss_asr_att, loss_asr_ctc
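    # Worked example (weights made up for illustration): with asr_weight=0.3
    # and ctc_weight=0.5, loss_asr = 0.5 * loss_asr_ctc + 0.5 * loss_asr_att
    # and loss = 0.3 * loss_asr + 0.7 * loss_st; with asr_weight=0, the model
    # trains on the ST loss alone.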
def _calc_st_loss(
self,
encoder_out: paddle.Tensor,
encoder_mask: paddle.Tensor,
ys_pad: paddle.Tensor,
ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
"""Calc attention loss.
Args:
encoder_out (paddle.Tensor): [B, Tmax, D]
encoder_mask (paddle.Tensor): [B, 1, Tmax]
ys_pad (paddle.Tensor): [B, Umax]
ys_pad_lens (paddle.Tensor): [B]
Returns:
Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id, )
return loss_att, acc_att
def _calc_att_loss(
self,
encoder_out: paddle.Tensor,
encoder_mask: paddle.Tensor,
ys_pad: paddle.Tensor,
ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
"""Calc attention loss.
Args:
encoder_out (paddle.Tensor): [B, Tmax, D]
encoder_mask (paddle.Tensor): [B, 1, Tmax]
ys_pad (paddle.Tensor): [B, Umax]
ys_pad_lens (paddle.Tensor): [B]
Returns:
Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id, )
return loss_att, acc_att
def _forward_encoder(
self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encoder pass.
Args:
speech (paddle.Tensor): [B, Tmax, D]
speech_lengths (paddle.Tensor): [B]
decoding_chunk_size (int, optional): chunk size. Defaults to -1.
num_decoding_left_chunks (int, optional): number of left chunks. Defaults to -1.
simulate_streaming (bool, optional): streaming or not. Defaults to False.
Returns:
Tuple[paddle.Tensor, paddle.Tensor]:
encoder hiddens (B, Tmax, D),
encoder hiddens mask (B, 1, Tmax).
"""
# Let's assume B = batch_size
# 1. Encoder
if simulate_streaming and decoding_chunk_size > 0:
encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
speech,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
else:
encoder_out, encoder_mask = self.encoder(
speech,
speech_lengths,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
return encoder_out, encoder_mask
def translate(
self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
beam_size: int=10,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False, ) -> paddle.Tensor:
""" Apply beam search on attention decoder
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_lengths (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
paddle.Tensor: decoding result, (batch, max_result_len)
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
device = speech.place
batch_size = speech.shape[0]
# Let's assume B = batch_size and N = beam_size
# 1. Encoder
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
encoder_dim = encoder_out.size(2)
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
encoder_mask = encoder_mask.unsqueeze(1).repeat(
1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len)
hyps = paddle.ones(
[running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1)
# log scale score
scores = paddle.to_tensor(
[0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
device) # (B*N, 1)
end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1)
cache: Optional[List[paddle.Tensor]] = None
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.st_decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
# 2.3 Second beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
scores = scores.view(-1, 1) # (B*N, 1)
# 2.4. Compute base index in top_k_index,
# regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
# then find offset_k_index in top_k_index
base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
1, beam_size) # (B, N)
base_k_index = base_k_index * beam_size * beam_size
best_k_index = base_k_index.view(-1) + offset_k_index.view(
-1) # (B*N)
# 2.5 Update best hyps
best_k_pred = paddle.index_select(
top_k_index.view(-1), index=best_k_index, axis=0) # (B*N)
best_hyps_index = best_k_index // beam_size
last_best_k_hyps = paddle.index_select(
hyps, index=best_hyps_index, axis=0) # (B*N, i)
hyps = paddle.cat(
(last_best_k_hyps, best_k_pred.view(-1, 1)),
dim=1) # (B*N, i+1)
# 2.6 Update end flag
end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
# 3. Select best of best
scores = scores.view(batch_size, beam_size)
# TODO: length normalization
best_index = paddle.argmax(scores, axis=-1).long() # (B)
best_hyps_index = best_index + paddle.arange(
batch_size, dtype=paddle.long) * beam_size
best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
best_hyps = best_hyps[:, 1:]
return best_hyps
@jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
@jit.export
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
subsampling_cache: Optional[paddle.Tensor]=None,
elayers_output_cache: Optional[List[paddle.Tensor]]=None,
conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
paddle.Tensor]]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
Args:
xs (paddle.Tensor): chunk input
subsampling_cache (Optional[paddle.Tensor]): subsampling cache
elayers_output_cache (Optional[List[paddle.Tensor]]):
transformer/conformer encoder layers output cache
conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
cnn cache
Returns:
paddle.Tensor: output, it ranges from time 0 to current chunk.
paddle.Tensor: subsampling cache
List[paddle.Tensor]: attention cache
List[paddle.Tensor]: conformer cnn cache
"""
return self.encoder.forward_chunk(
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
@jit.export
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (paddle.Tensor): encoder output
Returns:
paddle.Tensor: activation before ctc
"""
return self.ctc.log_softmax(xs)
@jit.export
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
hyps (paddle.Tensor): hyps from ctc prefix beam search, already
pad sos at the begining, (B, T)
hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
Returns:
paddle.Tensor: decoder output, (B, L)
"""
assert encoder_out.size(0) == 1
num_hyps = hyps.size(0)
assert hyps_lens.size(0) == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
return decoder_out
@paddle.no_grad()
def decode(self,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
lang_model_path: str,
beam_alpha: float,
beam_beta: float,
beam_size: int,
cutoff_prob: float,
cutoff_top_n: int,
num_processes: int,
ctc_weight: float=0.0,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False):
"""u2 decoding.
Args:
feats (Tensor): audio features, (B, T, D)
feats_lengths (Tensor): (B)
text_feature (TextFeaturizer): text feature object.
decoding_method (str): decoding mode, e.g.
'fullsentence',
'simultaneous'
lang_model_path (str): lm path.
beam_alpha (float): lm weight.
beam_beta (float): length penalty.
beam_size (int): beam size for search
cutoff_prob (float): for prune.
cutoff_top_n (int): for prune.
num_processes (int): number of parallel processes for decoding.
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here.
num_decoding_left_chunks (int, optional):
number of left chunks for decoding. Defaults to -1.
simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
Raises:
ValueError: when not support decoding_method.
Returns:
List[List[int]]: transcripts.
"""
batch_size = feats.size(0)
if decoding_method == 'fullsentence':
hyps = self.translate(
feats,
feats_lengths,
beam_size=beam_size,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
simulate_streaming=simulate_streaming)
hyps = [hyp.tolist() for hyp in hyps]
else:
raise ValueError(f"Not support decoding method: {decoding_method}")
res = [text_feature.defeaturize(hyp) for hyp in hyps]
return res
class U2STModel(U2STBaseModel):
def __init__(self, configs: dict):
vocab_size, encoder, decoder = U2STModel._init_from_config(configs)
if isinstance(decoder, tuple):
st_decoder, asr_decoder, ctc = decoder
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
st_decoder=st_decoder,
decoder=asr_decoder,
ctc=ctc,
**configs['model_conf'])
else:
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
st_decoder=decoder,
**configs['model_conf'])
@classmethod
def _init_from_config(cls, configs: dict):
"""init sub module for model.
Args:
configs (dict): config dict.
Raises:
ValueError: raise when using not support encoder type.
Returns:
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
"""
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['cmvn_file_type'])
global_cmvn = GlobalCMVN(
paddle.to_tensor(mean, dtype=paddle.float),
paddle.to_tensor(istd, dtype=paddle.float))
else:
global_cmvn = None
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
assert input_dim != 0, input_dim
assert vocab_size != 0, vocab_size
encoder_type = configs.get('encoder', 'transformer')
logger.info(f"U2 Encoder type: {encoder_type}")
if encoder_type == 'transformer':
encoder = TransformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
elif encoder_type == 'conformer':
encoder = ConformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
else:
raise ValueError(f"not support encoder type:{encoder_type}")
st_decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
asr_weight = configs['model_conf']['asr_weight']
logger.info(f"ASR Joint Training Weight: {asr_weight}")
if asr_weight > 0.:
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
ctc = CTCDecoder(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
return vocab_size, encoder, (st_decoder, decoder, ctc)
else:
return vocab_size, encoder, st_decoder
@classmethod
def from_config(cls, configs: dict):
"""init model.
Args:
configs (dict): config dict.
Raises:
ValueError: raise when using not support encoder type.
Returns:
nn.Layer: U2STModel
"""
model = cls(configs)
return model
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Args:
dataloader (paddle.io.DataLoader): not used.
config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
Returns:
DeepSpeech2Model: The model built from pretrained result.
"""
config.defrost()
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
model = cls.from_config(config)
if checkpoint_path:
infos = checkpoint.load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
class U2STInferModel(U2STModel):
def __init__(self, configs: dict):
super().__init__(configs)
def forward(self,
feats,
feats_lengths,
decoding_chunk_size=-1,
num_decoding_left_chunks=-1,
simulate_streaming=False):
"""export model function
Args:
feats (Tensor): [B, T, D]
feats_lengths (Tensor): [B]
Returns:
List[List[int]]: best path result
"""
return self.translate(
feats,
feats_lengths,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
simulate_streaming=simulate_streaming)
......@@ -114,7 +114,8 @@ class ConvBn(nn.Layer):
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
......
......@@ -219,11 +219,14 @@ class BaseEncoder(nn.Layer):
xs, pos_emb, _ = self.embed(
xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D)
if subsampling_cache is not None:
cache_size = subsampling_cache.size(1) #T
xs = paddle.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
# only used when using `RelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
offset=offset - cache_size, size=xs.size(1))
......@@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
# Real mask for transformer/conformer layers
masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
masks = masks.unsqueeze(1) #[B=1, L'=1, T]
r_elayers_output_cache = []
r_conformer_cnn_cache = []
for i, layer in enumerate(self.encoders):
......
......@@ -309,6 +309,6 @@ class RNNStack(nn.Layer):
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
......@@ -92,7 +92,7 @@ class Conv2dSubsampling4(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling4 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
......@@ -143,7 +143,7 @@ class Conv2dSubsampling6(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling6 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
......@@ -196,7 +196,7 @@ class Conv2dSubsampling8(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling8 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
......
......@@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
def __init__(self, clip_norm):
super().__init__(clip_norm)
def __repr__(self):
return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})"
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import Text
from paddle.optimizer import Optimizer
from paddle.regularizer import L2Decay
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.dynamic_import import instance_class
from deepspeech.utils.log import Log
__all__ = ["OptimizerFactory"]
logger = Log(__name__).getlog()
OPTIMIZER_DICT = {
"sgd": "paddle.optimizer:SGD",
"momentum": "paddle.optimizer:Momentum",
"adadelta": "paddle.optimizer:Adadelta",
"adam": "paddle.optimizer:Adam",
"adamw": "paddle.optimizer:AdamW",
}
def register_optimizer(cls):
"""Register optimizer."""
alias = cls.__name__.lower()
OPTIMIZER_DICT[alias] = cls.__module__ + ":" + cls.__name__
return cls
def dynamic_import_optimizer(module):
"""Import Optimizer class dynamically.
Args:
module (str): module_name:class_name or alias in `OPTIMIZER_DICT`
Returns:
type: Optimizer class
"""
module_class = dynamic_import(module, OPTIMIZER_DICT)
assert issubclass(module_class,
Optimizer), f"{module} does not implement Optimizer"
return module_class
class OptimizerFactory():
@classmethod
def from_args(cls, name: str, args: Dict[Text, Any]):
assert "parameters" in args, "parameters not in args."
assert "learning_rate" in args, "learning_rate not in args."
grad_clip = ClipGradByGlobalNormWithLog(
args['grad_clip']) if "grad_clip" in args else None
weight_decay = L2Decay(
args['weight_decay']) if "weight_decay" in args else None
module_class = dynamic_import_optimizer(name.lower())
if weight_decay:
logger.info(f'WeightDecay: {weight_decay}')
if grad_clip:
logger.info(f'GradClip: {grad_clip}')
logger.info(
f"Optimizer: {module_class.__name__} {args['learning_rate']}")
args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
return instance_class(module_class, args)
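# Illustrative usage sketch; the model and hyperparameter values below are
# assumptions chosen only for the example. "adam" resolves through
# OPTIMIZER_DICT above.
def _optimizer_factory_demo():
    import paddle
    model = paddle.nn.Linear(4, 2)
    args = {
        "learning_rate": 1e-3,
        "parameters": model.parameters(),
        "grad_clip": 5.0,  # wrapped into ClipGradByGlobalNormWithLog
        "weight_decay": 1e-6,  # wrapped into L2Decay
    }
    return OptimizerFactory.from_args("adam", args)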
......@@ -11,18 +11,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import Text
from typing import Union
from paddle.optimizer.lr import LRScheduler
from typeguard import check_argument_types
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.dynamic_import import instance_class
from deepspeech.utils.log import Log
__all__ = ["WarmupLR"]
__all__ = ["WarmupLR", "LRSchedulerFactory"]
logger = Log(__name__).getlog()
SCHEDULER_DICT = {
"noam": "paddle.optimizer.lr:NoamDecay",
"expdecaylr": "paddle.optimizer.lr:ExponentialDecay",
"piecewisedecay": "paddle.optimizer.lr:PiecewiseDecay",
}
def register_scheduler(cls):
"""Register scheduler."""
alias = cls.__name__.lower()
SCHEDULER_DICT[alias] = cls.__module__ + ":" + cls.__name__
return cls
def dynamic_import_scheduler(module):
"""Import Scheduler class dynamically.
Args:
module (str): module_name:class_name or alias in `SCHEDULER_DICT`
Returns:
type: Scheduler class
"""
module_class = dynamic_import(module, SCHEDULER_DICT)
assert issubclass(module_class,
LRScheduler), f"{module} does not implement LRScheduler"
return module_class
@register_scheduler
class WarmupLR(LRScheduler):
"""The WarmupLR scheduler
This scheduler is almost the same as the NoamLR scheduler except for the following
......@@ -40,7 +75,8 @@ class WarmupLR(LRScheduler):
warmup_steps: Union[int, float]=25000,
learning_rate=1.0,
last_epoch=-1,
verbose=False,
**kwargs):
assert check_argument_types()
self.warmup_steps = warmup_steps
super().__init__(learning_rate, last_epoch, verbose)
......@@ -64,3 +100,10 @@ class WarmupLR(LRScheduler):
None
'''
self.step(epoch=step)
class LRSchedulerFactory():
@classmethod
def from_args(cls, name: str, args: Dict[Text, Any]):
module_class = dynamic_import_scheduler(name.lower())
return instance_class(module_class, args)
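# Illustrative usage sketch; the values are assumptions chosen for the
# example. "warmuplr" is the alias registered above by @register_scheduler.
def _scheduler_factory_demo():
    args = {"learning_rate": 0.002, "warmup_steps": 25000, "verbose": False}
    return LRSchedulerFactory.from_args("warmuplr", args)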
......@@ -18,8 +18,8 @@ import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.utils import checkpoint
from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
__all__ = ["Trainer"]
......@@ -29,37 +29,37 @@ logger = Log(__name__).getlog()
class Trainer():
"""
An experiment template in order to structure the training code and take
care of saving, loading, logging and visualization stuff. It's intended to
be flexible and simple.
So it only handles the output directory (create a directory for the output,
create a checkpoint directory, dump the config in use and create
visualizer and logger) in a standard way without enforcing any
input-output protocols to the model and dataloader. It leaves the main
part for the user to implement their own (setup the model, criterion,
optimizer, define a training step, define a validation function and
customize all the text and visual logs).
It does not save too much boilerplate code. The users still have to write
the forward/backward/update manually, but they are free to add
non-standard behaviors if needed.
We have some conventions to follow.
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
``valid_loader``, ``config`` and ``args`` attributes.
2. The config should have a ``training`` field, which has
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
used as the trigger to invoke validation, checkpointing and stop of the
experiment.
3. There are four methods, namely ``train_batch``, ``valid``,
``setup_model`` and ``setup_dataloader`` that should be implemented.
Feel free to add/overwrite other methods and standalone functions if you
need.
Parameters
----------
config: yacs.config.CfgNode
The configuration used for the experiment.
args: argparse.Namespace
The parsed command line arguments.
Examples
......@@ -68,16 +68,16 @@ class Trainer():
>>> exp = Trainer(config, args)
>>> exp.setup()
>>> exp.run()
>>>
>>> config = get_cfg_defaults()
>>> parser = default_argument_parser()
>>> args = parser.parse_args()
>>> if args.config:
>>> config.merge_from_file(args.config)
>>> if args.opts:
>>> config.merge_from_list(args.opts)
>>> config.freeze()
>>>
>>> if args.nprocs > 1 and args.device == "gpu":
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
>>> else:
......@@ -114,7 +114,7 @@ class Trainer():
@property
def parallel(self):
"""A flag indicating whether the experiment should run with
"""A flag indicating whether the experiment should run with
multiprocessing.
"""
return self.args.device == "gpu" and self.args.nprocs > 1
......@@ -139,19 +139,19 @@ class Trainer():
"epoch": self.epoch,
"lr": self.optimizer.get_lr()
})
self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
scratch = None
infos = self.checkpoint.load_latest_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
......@@ -180,9 +180,8 @@ class Trainer():
from_scratch = self.resume_or_scratch()
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
self.lr_scheduler.step(self.epoch)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
......@@ -254,7 +253,7 @@ class Trainer():
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
......@@ -263,6 +262,10 @@ class Trainer():
self.checkpoint_dir = checkpoint_dir
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
@mp_tools.rank_zero_only
def destory(self):
"""Close visualizer to avoid hanging after training"""
......@@ -273,13 +276,13 @@ class Trainer():
@mp_tools.rank_zero_only
def setup_visualizer(self):
"""Initialize a visualizer to log the experiment.
The visual log is saved in the output directory.
Notes
------
Only the main process has a visualizer with it. Using multiple
visualizers in multiple processes to write to the same log file may cause
unexpected behaviors.
"""
# visualizer
......@@ -288,9 +291,9 @@ class Trainer():
@mp_tools.rank_zero_only
def dump_config(self):
"""Save the configuration used for this experiment.
It is saved in to ``config.yaml`` in the output directory at the
"""Save the configuration used for this experiment.
It is saved in to ``config.yaml`` in the output directory at the
beginning of the experiment.
"""
with open(self.output_dir / "config.yaml", 'wt') as f:
......@@ -308,13 +311,13 @@ class Trainer():
raise NotImplementedError("valid should be implemented.")
def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should
"""Setup model, criterion and optimizer, etc. A subclass should
implement this method.
"""
raise NotImplementedError("setup_model should be implemented.")
def setup_dataloader(self):
"""Setup training dataloader and validation dataloader. A subclass
"""Setup training dataloader and validation dataloader. A subclass
should implement this method.
"""
raise NotImplementedError("setup_dataloader should be implemented.")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import sacrebleu
__all__ = ['bleu', 'char_bleu']
def bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in word-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
:type hypothesis: list[str]
:raises ValueError: If the reference length is zero.
"""
return sacrebleu.corpus_bleu(hypothesis, reference)
def char_bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in char-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
:type hypothesis: list[str]
:raises ValueError: If the reference number is zero.
"""
hypothesis = [' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref]
for ref in reference]
return sacrebleu.corpus_bleu(hypothesis, reference)
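# Illustrative usage sketch with toy strings (made up for the example); an
# exact match scores 100.0.
def _bleu_demo():
    hyp = ["今天天气很好"]
    refs = [["今天天气很好"]]  # one reference stream, one reference per hypothesis
    return char_bleu(hyp, refs).score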
......@@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import os
import re
from pathlib import Path
from typing import Mapping
from typing import Text
from typing import Union
import paddle
......@@ -25,128 +28,271 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["load_parameters", "save_parameters"]
def _load_latest_checkpoint(checkpoint_dir: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
Returns:
int: the latest iteration number. -1 for no checkpoint to load.
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
if not os.path.isfile(checkpoint_record):
return -1
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
iteration = int(latest_checkpoint.split(":")[-1])
return iteration
def _save_record(checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model to be checkpoint record.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_record, "a+") as handle:
handle.write("model_checkpoint_path:{}\n".format(iteration))
def load_parameters(model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a specific model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
configs = {}
if checkpoint_path is not None:
tag = os.path.basename(checkpoint_path).split(":")[-1]
elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir)
if iteration == -1:
return configs
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
)
rank = dist.get_rank()
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
logger.info("Rank {}: loaded optimizer state from {}".format(
rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = json.load(fin)
return configs
@mp_tools.rank_zero_only
def save_parameters(checkpoint_dir: str,
tag_or_iteration: Union[int, str],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
tag_or_iteration (int or str): the latest iteration(step or epoch) number.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
Defaults to None.
infos (dict or None): any info you want to save.
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration))
model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path)
logger.info("Saved model to {}".format(params_path))
if optimizer:
opt_dict = optimizer.state_dict()
__all__ = ["Checkpoint"]
class Checkpoint():
def __init__(self, kbest_n: int=5, latest_n: int=1):
self.best_records: Mapping[Path, float] = {}
self.latest_records = []
self.kbest_n = kbest_n
self.latest_n = latest_n
self._save_all = (kbest_n == -1)
def add_checkpoint(self,
checkpoint_dir,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
"""Save checkpoint in best_n and latest_n.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
infos (dict or None): any info you want to save.
metric_type (str, optional): metric type. Defaults to "val_loss".
"""
if (metric_type not in infos.keys()):
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
return
#save best
if self._should_save_best(infos[metric_type]):
self._save_best_checkpoint_and_update(
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
#save latest
self._save_latest_checkpoint_and_update(
checkpoint_dir, tag_or_iteration, model, optimizer, infos)
if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None,
record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
record_file "checkpoint_latest" or "checkpoint_best"
Returns:
configs (dict): meta info saved with the checkpoint, e.g. epoch or step and lr.
"""
configs = {}
if checkpoint_path is not None:
pass
elif checkpoint_dir is not None and record_file is not None:
# load checkpoint from record file
checkpoint_record = os.path.join(checkpoint_dir, record_file)
iteration = self._load_checkpoint_idx(checkpoint_record)
if iteration == -1:
return configs
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
)
rank = dist.get_rank()
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
logger.info("Rank {}: loaded optimizer state from {}".format(
rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = json.load(fin)
return configs
def load_latest_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): meta info saved with the checkpoint, e.g. epoch or step and lr.
"""
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_latest")
def load_best_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): meta info saved with the checkpoint, e.g. epoch or step and lr.
"""
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_best")
def _should_save_best(self, metric: float) -> bool:
if not self._best_full():
return True
# already full
worst_record_path = max(self.best_records, key=self.best_records.get)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric = self.best_records[worst_record_path]
return metric < worst_metric
def _best_full(self):
return (not self._save_all) and len(self.best_records) == self.kbest_n
def _latest_full(self):
return len(self.latest_records) == self.latest_n
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
tag_or_iteration, model, optimizer,
infos):
# remove the worst
if self._best_full():
worst_record_path = max(self.best_records,
key=self.best_records.get)
self.best_records.pop(worst_record_path)
if (worst_record_path not in self.latest_records):
logger.info(
"remove the worst checkpoint: {}".format(worst_record_path))
self._del_checkpoint(checkpoint_dir, worst_record_path)
# add the new one
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
self.best_records[tag_or_iteration] = metric
def _save_latest_checkpoint_and_update(
self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
# remove the old
if self._latest_full():
to_del_fn = self.latest_records.pop(0)
if (to_del_fn not in self.best_records.keys()):
logger.info(
"remove the latest checkpoint: {}".format(to_del_fn))
self._del_checkpoint(checkpoint_dir, to_del_fn)
self.latest_records.append(tag_or_iteration)
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration))
for filename in glob.glob(checkpoint_path + ".*"):
os.remove(filename)
logger.info("delete file: {}".format(filename))
def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_path (str): the saved path of checkpoint.
Returns:
int: the latest iteration number. -1 for no checkpoint to load.
"""
if not os.path.isfile(checkpoint_record):
return -1
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
iteration = int(latest_checkpoint.split(":")[-1])
return iteration
def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model to be checkpoint record.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_record_latest = os.path.join(checkpoint_dir,
"checkpoint_latest")
checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
with open(checkpoint_record_best, "w") as handle:
for i in self.best_records.keys():
handle.write("model_checkpoint_path:{}\n".format(i))
with open(checkpoint_record_latest, "w") as handle:
for i in self.latest_records:
handle.write("model_checkpoint_path:{}\n".format(i))
@mp_tools.rank_zero_only
def _save_parameters(self,
checkpoint_dir: str,
tag_or_iteration: Union[int, str],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
tag_or_iteration (int or str): the latest iteration(step or epoch) number.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
Defaults to None.
infos (dict or None): any info you want to save.
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration))
model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path)
logger.info("Saved model to {}".format(params_path))
info_path = re.sub('.pdparams$', '.json', params_path)
infos = {} if infos is None else infos
with open(info_path, 'w') as fout:
data = json.dumps(infos)
fout.write(data)
if optimizer:
opt_dict = optimizer.state_dict()
optimizer_path = checkpoint_path + ".pdopt"
paddle.save(opt_dict, optimizer_path)
logger.info("Saved optimzier state to {}".format(optimizer_path))
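# A minimal usage sketch of the Checkpoint helper above, assuming a training
# loop already provides `model`, `optimizer`, the current `epoch` and `val_loss`:
#
#   checkpointer = Checkpoint(kbest_n=5, latest_n=2)
#   checkpointer.add_checkpoint(
#       checkpoint_dir="exp/ckpt/checkpoints",
#       tag_or_iteration=epoch,
#       model=model,
#       optimizer=optimizer,
#       infos={"epoch": epoch, "val_loss": val_loss},
#       metric_type="val_loss")
#   # later, e.g. before evaluation:
#   infos = checkpointer.load_best_parameters(
#       model, checkpoint_dir="exp/ckpt/checkpoints")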
......@@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
# add non-blank into new_hyp
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
# skip repeat label
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
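# A minimal sketch of the collapse rule above (blank_id=0 assumed):
#   >>> remove_duplicates_and_blank([0, 0, 1, 1, 0, 2, 2, 0], blank_id=0)
#   [1, 2]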
def insert_blank(label: np.ndarray, blank_id: int=0):
def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
"""Insert blank token between every two label token.
"abcdefg" -> "-a-b-c-d-e-f-g-"
Args:
label ([np.ndarray]): label ids, (L).
label ([np.ndarray]): label ids, List[int], (L).
blank_id (int, optional): blank id. Defaults to 0.
Returns:
......@@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0):
label = np.expand_dims(label, 1) #[L, 1]
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1) #[L, 2]
label = label.reshape(-1) #[2L]
label = np.append(label, label[0]) #[2L + 1]
label = label.reshape(-1) #[2L], -l-l-l
label = np.append(label, label[0]) #[2L + 1], -l-l-l-
return label
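# A minimal sketch of the interleaving above (blank_id=0 assumed):
#   >>> insert_blank(np.array([1, 2, 3]), blank_id=0)
#   array([0, 1, 0, 2, 0, 3, 0])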
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
blank_id=0) -> list:
blank_id=0) -> List[int]:
"""ctc forced alignment.
https://distill.pub/2017/ctc/
......@@ -77,23 +79,25 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index
Returns:
paddle.Tensor: best alignment result, (T).
List[int]: best alignment result, (T).
"""
y_insert_blank = insert_blank(y, blank_id)
y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
) # state path
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1))
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
for t in range(1, ctc_probs.size(0)): # T
for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = paddle.to_tensor(
......@@ -106,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
y_insert_blank[s]]
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
y_insert_blank[s])]
state_path[t, s] = prev_state[paddle.argmax(candidates)]
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)
# TODO(Hui Zhang): zeros not support paddle.int16
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)
candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from typing import Any
from typing import Dict
from typing import List
from typing import Text
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import has_tensor
logger = Log(__name__).getlog()
__all__ = ["dynamic_import", "instance_class"]
def dynamic_import(import_path, alias=dict()):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
e.g., 'deepspeech.models.u2:U2Model'
:param dict alias: shortcut for registered class
:return: imported class
"""
if import_path not in alias and ":" not in import_path:
raise ValueError("import_path should be one of {} or "
'include ":", e.g. "deepspeech.models.u2:U2Model" : '
"{}".format(set(alias), import_path))
if ":" not in import_path:
import_path = alias[import_path]
module_name, objname = import_path.split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)
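# A minimal usage sketch, reusing the path from the docstring:
#   U2Model = dynamic_import('deepspeech.models.u2:U2Model')
#   U2Model = dynamic_import('u2', alias={'u2': 'deepspeech.models.u2:U2Model'})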
def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
# filter by `valid_keys` and filter `val` is not None
new_args = {
key: val
for key, val in args.items() if (key in valid_keys and val is not None)
}
return new_args
def filter_out_tenosr(args: Dict[Text, Any]):
return {key: val for key, val in args.items() if not has_tensor(val)}
def instance_class(module_class, args: Dict[Text, Any]):
valid_keys = inspect.signature(module_class).parameters.keys()
new_args = filter_valid_args(args, valid_keys)
logger.info(
f"Instance: {module_class.__name__} {filter_out_tenosr(new_args)}.")
return module_class(**new_args)
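# A minimal sketch with a hypothetical class (not from this repo): only kwargs
# matching the constructor signature and not None are passed through.
#   class Foo:
#       def __init__(self, a, b=0):
#           self.a, self.b = a, b
#   foo = instance_class(Foo, {"a": 1, "b": 2, "c": 3, "d": None})  # Foo(a=1, b=2)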
......@@ -17,6 +17,8 @@ import os
import socket
import sys
from paddle import inference
FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'
......@@ -146,3 +148,35 @@ class Log():
def getlog(self):
return self.logger
class Autolog:
def __init__(self,
batch_size,
model_name="DeepSpeech",
model_precision="fp32"):
import auto_log
pid = os.getpid()
if (os.environ.get('CUDA_VISIBLE_DEVICES', '').strip() != ''):
gpu_id = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0])
infer_config = inference.Config()
infer_config.enable_use_gpu(100, gpu_id)
else:
gpu_id = None
infer_config = inference.Config()
autolog = auto_log.AutoLogger(
model_name=model_name,
model_precision=model_precision,
batch_size=batch_size,
data_shape="dynamic",
save_path="./output/auto_log.lpg",
inference_config=infer_config,
pids=pid,
process_name=None,
gpu_ids=gpu_id,
time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
warmup=0)
self.autolog = autolog
def getlog(self):
return self.autolog
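# A minimal usage sketch; only the constructor and getlog() defined above are
# relied on, the returned auto_log.AutoLogger API itself is not shown here.
#   perf_logger = Autolog(batch_size=1, model_name="DeepSpeech").getlog()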
......@@ -48,9 +48,9 @@ def warm_up_test(audio_process_handler,
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
print("Warm-up Test Case %d: %s" % (idx, sample['feat']))
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
transcript = audio_process_handler(sample['feat'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
......
......@@ -19,11 +19,25 @@ import paddle
from deepspeech.utils.log import Log
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Log(__name__).getlog()
def has_tensor(val):
if isinstance(val, (list, tuple)):
for item in val:
if has_tensor(item):
return True
elif isinstance(val, dict):
for k, v in val.items():
if has_tensor(v):
return True
else:
return paddle.is_tensor(val)
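# A minimal sketch of the recursive check above:
#   >>> has_tensor({"feat": paddle.zeros([2, 80]), "len": 2})
#   True
#   >>> bool(has_tensor(["utt1", 16000]))  # no tensor anywhere -> falsy
#   False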
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Text
import textgrid
def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
"""segment ctc alignment ids by continuous blank and repeat label.
Args:
alignment (List[int]): ctc alignment id sequence.
e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[List[int]]: token align, segmented alignment id sequence.
e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
"""
# convert alignment to a praat-compatible format; praat is a tool for doing
# phonetics by computer and helps with analyzing the alignment
align_segs = []
# get frames level duration for each token
start = 0
end = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == blank_id: # blank
end += 1
if end == len(alignment):
align_segs[-1].extend(alignment[start:])
break
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[
end]: # repeat label
end += 1
align_segs.append(alignment[start:end])
start = end
return align_segs
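# A minimal sketch, reusing the docstring example (blank_id=0):
#   >>> segment_alignment([0, 0, 0, 1, 1, 1, 2, 0, 0, 3])
#   [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]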
def align_to_tierformat(align_segs: List[List[int]],
subsample: int,
token_dict: Dict[int, Text],
blank_id=0) -> List[Text]:
"""Generate textgrid.Interval format from alignment segmentations.
Args:
align_segs (List[List[int]]): segmented ctc alignment ids.
subsample (int): encoder subsample rate; one output frame covers subsample * 10 ms (25ms frame_length, 10ms hop_length).
token_dict (Dict[int, Text]): int -> str map.
Returns:
List[Text]: list of textgrid.Interval text, str(start, end, text).
"""
hop_length = 10 # ms
second_ms = 1000 # ms
frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length
second_per_frame = 1.0 / frame_per_second
begin = 0
duration = 0
tierformat = []
for idx, tokens in enumerate(align_segs):
token_len = len(tokens)
token = tokens[-1]
# time duration in second
duration = token_len * subsample * second_per_frame
if idx < len(align_segs) - 1:
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
else:
for i in tokens:
if i != blank_id:
token = i
break
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
begin = begin + duration
return tierformat
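# A minimal sketch with hypothetical token ids and subsample=4 (40 ms per frame):
#   segs = [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
#   align_to_tierformat(segs, subsample=4, token_dict={1: 'ni', 2: 'hao', 3: 'ma'})
#   # -> ['0.00 0.24 ni\n', '0.24 0.28 hao\n', '0.28 0.40 ma\n']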
def generate_textgrid(maxtime: float,
intervals: List[Text],
output: Text,
name: Text='ali') -> None:
"""Create alignment textgrid file.
Args:
maxtime (float): audio duration.
intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
output (Text): textgrid filepath.
name (Text, optional): tier or layer name. Defaults to 'ali'.
"""
# Download Praat: https://www.fon.hum.uva.nl/praat/
avg_interval = maxtime / (len(intervals) + 1)
print(f"average second/token: {avg_interval}")
margin = 0.0001
tg = textgrid.TextGrid(maxTime=maxtime)
tier = textgrid.IntervalTier(name=name, maxTime=maxtime)
i = 0
for dur in intervals:
s, e, text = dur.split()
tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)
tg.append(tier)
tg.write(output)
print("successfully generator textgrid {}.".format(output))
......@@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float:
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
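# A minimal sketch: a numerically stable log-sum-exp over log-probabilities.
#   log_add([math.log(0.25), math.log(0.25)])  # ~= math.log(0.5)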
def get_subsample(config):
"""Subsample rate from config.
Args:
config (yacs.config.CfgNode): yaml config
Returns:
int: subsample rate.
"""
input_layer = config["model"]["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
return 6
elif input_layer == "conv2d8":
return 8
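# A minimal sketch, assuming a yacs CfgNode with model.encoder_conf.input_layer set:
#   subsample = get_subsample(config)  # conv2d -> 4, conv2d6 -> 6, conv2d8 -> 8
#   ms_per_encoder_frame = subsample * 10  # with a 10 ms hop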
......@@ -4,7 +4,7 @@ export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# ASR
* s0 for deepspeech2
* s0 for deepspeech2 offline
* s1 for u2
......@@ -2,9 +2,10 @@
## Deepspeech2
| Model | release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
| DeepSpeech2 | 1.8.5 | - | test | - | 0.080447 |
| Model | Params | Release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382,0.073507 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 |
......@@ -3,31 +3,36 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
mean_std_filepath: data/mean_std.json
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 64 # one gpu
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
target_sample_rate: 16000
max_freq: None
n_fft: None
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
delta_delta: False
dither: 1.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
random_seed: 0
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
num_workers: 2
model:
num_conv_layers: 2
......@@ -43,6 +48,9 @@ training:
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
......
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 32 # one gpu
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear #linear, mfcc, fbank
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
rnn_direction: forward # [forward, bidirect]
num_fc_layers: 1
fc_layers_size_list: 512,
use_gru: True
training:
n_epoch: 50
lr: 2e-3
lr_decay: 0.83 # 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 1.9
beta: 5.0
beam_size: 300
cutoff_prob: 0.99
cutoff_top_n: 40
num_proc_bsearch: 10
#! /usr/bin/env bash
#!/bin/bash
source path.sh
......
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
#! /usr/bin/env bash
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
if [ $# != 4 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
exit -1
fi
......@@ -11,9 +11,10 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
model_type=$4
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in export!"
......
#! /usr/bin/env bash
#!/bin/bash
# TODO: replace the model with a mandarin model
if [[ $# != 1 ]];then
......@@ -15,10 +15,10 @@ if [ $? -ne 0 ]; then
fi
# download well-trained model
bash local/download_model.sh
if [ $? -ne 0 ]; then
exit 1
fi
#bash local/download_model.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
# start demo server
CUDA_VISIBLE_DEVICES=0 \
......@@ -29,7 +29,7 @@ python3 -u ${BIN_DIR}/deploy/server.py \
--host_ip="localhost" \
--host_port=8086 \
--speech_save_dir="demo_cache" \
--checkpoint_path ${1}
--checkpoint_path ${1}
if [ $? -ne 0 ]; then
echo "Failed in starting demo server!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
......@@ -9,11 +9,12 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh
......@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
......@@ -10,9 +10,10 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
model_type=$3
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
#! /usr/bin/env bash
#!/bin/bash
# grid-search for hyper-parameters in language model
python3 -u ${BIN_DIR}/tune.py \
......
export MAIN_ROOT=${PWD}/../../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
......@@ -2,11 +2,12 @@
set -e
source path.sh
gpus=0
gpus=0,1,2,3
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......@@ -21,20 +22,20 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
......@@ -2,15 +2,26 @@
## Conformer
| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| conformer | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 |
| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 |
| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 |
| conformer | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 |
## Chunk Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16, -1 | - | 0.061939 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16, -1 | - | 0.070806 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16, -1 | - | 0.070739 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention_rescoring | 16, -1 | - | 0.059400 |
## Transformer
| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | ---|
| transformer | conf/transformer.yaml | spec_aug + shift | test | attention | - | - |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | ---|
| transformer | - | conf/transformer.yaml | spec_aug + shift | test | attention | - | - |
......@@ -3,17 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 32
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 32
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -30,7 +33,7 @@ data:
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
num_workers: 2
# network architecture
......@@ -78,7 +81,7 @@ model:
training:
n_epoch: 180
n_epoch: 240
accum_grad: 4
global_grad_clip: 5.0
optim: adam
......@@ -90,6 +93,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
......
......@@ -3,17 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 64
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -32,7 +35,6 @@ data:
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
......@@ -86,6 +88,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
......
#!/bin/bash
# To be run from one directory above this script.
. ./path.sh
text=data/local/lm/text
lexicon=data/local/dict/lexicon.txt
for f in "$text" "$lexicon"; do
[ ! -f $f ] && echo "$0: No such file $f" && exit 1;
done
# Check SRILM tools
if ! which ngram-count > /dev/null; then
echo "srilm tools are not found, please download it and install it from: "
echo "http://www.speech.sri.com/projects/srilm/download.html"
echo "Then add the tools to your PATH"
exit 1
fi
# This script takes no arguments. It assumes you have already run
# aishell_data_prep.sh.
# It takes as input the files
# data/local/lm/text
# data/local/dict/lexicon.txt
dir=data/local/lm
mkdir -p $dir
cleantext=$dir/text.no_oov
cat $text | awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
> $cleantext || exit 1;
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \
sort -nr > $dir/word.counts || exit 1;
# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM-- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1;
cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
heldout_sent=10000 # Don't change this if you want result to be comparable with
# kaldi_lm results
mkdir -p $dir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
head -$heldout_sent > $dir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
tail -n +$heldout_sent > $dir/train
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
-map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
ngram -lm $dir/lm.arpa -ppl $dir/heldout
\ No newline at end of file
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
../../s0/local/download_lm_ch.sh
\ No newline at end of file
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
......@@ -13,7 +13,7 @@ ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
......@@ -9,15 +9,17 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_ch.sh
......@@ -28,7 +30,12 @@ mkdir -p exp
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
batch_size=64
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
fi
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \
......
#!/bin/bash
set -eo pipefail
stage=-1
stop_stage=100
corpus=aishell
lmtype=srilm
source utils/parse_options.sh
data=${MAIN_ROOT}/examples/dataset/${corpus}
lexicon=$data/resource_aishell/lexicon.txt
text=$data/data_aishell/transcript/aishell_transcript_v0.8.txt
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# 7.1 Prepare dict
unit_file=data/vocab.txt
mkdir -p data/local/dict
cp $unit_file data/local/dict/units.txt
utils/fst/prepare_dict.py \
--unit_file $unit_file \
--in_lexicon ${lexicon} \
--out_lexicon data/local/dict/lexicon.txt
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# 7.2 Train lm
lm=data/local/lm
mkdir -p data/train
mkdir -p $lm
utils/manifest_key_value.py \
--manifest_path data/manifest.train \
--output_path data/train
utils/filter_scp.pl data/train/text \
$text > $lm/text
if [ $lmtype == 'srilm' ];then
local/aishell_train_lms.sh
else
utils/ngram_train.sh --order 3 $lm/text $lm/lm.arpa
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# 7.3 Build decoding TLG
utils/fst/compile_lexicon_token_fst.sh \
data/local/dict data/local/tmp data/local/lang
utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
fi
echo "Aishell build TLG done."
exit 0
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
......@@ -12,7 +12,7 @@ config_path=$1
ckpt_name=$2
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
......
export MAIN_ROOT=${PWD}/../../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# model exp
MODEL=u2
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64
# Kaldi
export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!"
. $KALDI_ROOT/tools/config/common_path.sh || true
......@@ -25,15 +25,26 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
# Optionally, you can add LM and test it with runtime.
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# train lm and build TLG
./local/tlg.sh --corpus aishell --lmtype srilm
fi
../../../utils
\ No newline at end of file
# MandarinK8
## Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 45.73 M | conf/conformer.yaml | spec_aug + shift | test | attention | 2.1794936656951904 | 0.102304 |
| conformer | 45.73 M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 2.1794936656951904 | 0.084295 |
| conformer | 45.73 M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 2.1794936656951904 | 0.084340 |
| conformer | 45.73 M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | 2.1794936656951904 | 0.081675 |
## Chunk Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 45.73 M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16, -1 | 2.23287845 | 0.087982 |
| conformer | 45.73 M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16, -1 | 2.23287845 | 0.086962 |
| conformer | 45.73 M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16, -1 | 2.23287845 | 0.086741 |
| conformer | 45.73 M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention_rescoring | 16, -1 | 2.23287845 | 0.083495 |
[
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 0.0
},
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
},
"prob": 1.0
}
]
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 32
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 8000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 240
accum_grad: 4
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-6
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
error_rate_type: cer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: true # simulate streaming inference. Defaults to False.
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.0
max_output_input_ratio: .inf
collator:
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 32
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 8000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 100 # 50 will be lowest
accum_grad: 4
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.002
weight_decay: 1e-6
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
error_rate_type: cer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#! /usr/bin/env bash
stage=-1
stop_stage=100
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
for dataset in train dev test; do
mv data/manifest.${dataset} data/manifest.${dataset}.raw
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# download data, generate manifests
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type="char" \
--count_threshold=0 \
--vocab_path="data/vocab.txt" \
--manifest_paths "data/manifest.train.raw"
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--specgram_type="fbank" \
--feat_dim=80 \
--delta_delta=false \
--stride_ms=10.0 \
--window_ms=25.0 \
--sample_rate=8000 \
--use_dB_normalization=False \
--num_samples=-1 \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
} &
done
wait
fi
echo "data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#! /usr/bin/env bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
if [ $? -ne 0 ]; then
echo "Failed in export!"
exit 1
fi
exit 0
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/conformer.yaml
avg_num=20
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
export MAIN_ROOT=${PWD}/../../
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
export MAIN_ROOT=${PWD}/../../
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
*.tgz
manifest.*
*.meta
aidatatang_200zh/
\ No newline at end of file
# [Aidatatang_200zh](http://www.openslr.org/62/)
Aidatatang_200zh is a free Chinese Mandarin speech corpus provided by Beijing DataTang Technology Co., Ltd under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License.
The contents and the corresponding descriptions of the corpus include:
* The corpus contains 200 hours of acoustic data, which is mostly mobile-recorded data.
* 600 speakers from different accent areas in China were invited to participate in the recording.
* The transcription accuracy for each sentence is larger than 98%.
* Recordings are conducted in a quiet indoor environment.
* The database is divided into training set, validation set, and testing set in a ratio of 7:1:2.
* Detailed information such as speech data coding and speaker information is preserved in the metadata file.
* Segmented transcripts are also provided.
The corpus aims to support researchers in speech recognition, machine translation, voiceprint recognition, and other speech-related fields. Therefore, the corpus is totally free for academic use.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare aidatatang_200zh mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/62'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/62'
DATA_URL = URL_ROOT + '/aidatatang_200zh.tgz'
MD5_DATA = '6e0f4f39cd5f667a7ee53c397c8d0949'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/aidatatang_200zh",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aidatatang_200_zh_transcript.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '':
continue
audio_id, text = line.split(' ', 1)
# character-level text: remove whitespace
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'corpus/', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
if not fname.endswith('.wav'):
continue
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': text,
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, subset)
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'corpus')
for subfolder, dirlist, filelist in sorted(os.walk(audio_dir)):
for sub in dirlist:
print(f"unpack dir {sub}...")
for folder, _, filelist in sorted(
os.walk(os.path.join(subfolder, sub))):
for ftar in filelist:
unpack(os.path.join(folder, ftar), folder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix,
subset='aidatatang_200zh')
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
data_aishell*
*.meta
manifest.*
*.tgz
resource_aishell
# [Aishell1](http://www.openslr.org/33/)
This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, whose utterances cover 11 domains, including smart home, autonomous driving, and industrial production. All recordings were made in a quiet indoor environment using 3 different devices at the same time: a high-fidelity microphone (44.1kHz, 16-bit), an Android-system mobile phone (16kHz, 16-bit), and an iOS-system mobile phone (16kHz, 16-bit). The high-fidelity audio was re-sampled to 16kHz to build AISHELL-ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy rate is above 95%, achieved through professional speech annotation and strict quality inspection. The corpus is divided into training, development, and testing sets. (This database is free for academic research; commercial use requires permission.)
......@@ -31,9 +31,11 @@ from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/33'
URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
RESOURCE_URL = URL_ROOT + '/resource_aishell.tgz'
MD5_RESOURCE = '957d480a0fcac85fc18e550756f624e5'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -60,18 +62,22 @@ def create_manifest(data_dir, manifest_path_prefix):
if line == '':
continue
audio_id, text = line.split(' ', 1)
# remove whitespace
# remove whitespace, character-level text
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'wav', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
# if there is no transcription for the audio, skip it
if audio_id not in transcript_dict:
continue
......@@ -81,22 +87,34 @@ def create_manifest(data_dir, manifest_path_prefix):
json_lines.append(
json.dumps(
{
'utt':
os.path.splitext(os.path.basename(audio_path))[0],
'feat':
audio_path,
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text':
text
'text': text
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
......@@ -110,7 +128,9 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path):
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
if manifest_path:
create_manifest(data_dir, manifest_path)
def main():
......@@ -123,6 +143,14 @@ def main():
target_dir=args.target_dir,
manifest_path=args.manifest_prefix)
prepare_dataset(
url=RESOURCE_URL,
md5sum=MD5_RESOURCE,
target_dir=args.target_dir,
manifest_path=None)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
# [Aishell3](http://www.openslr.org/93/)
AISHELL-3 is a large-scale, high-fidelity multi-speaker Mandarin speech corpus that can be used to train multi-speaker Text-to-Speech (TTS) systems. The corpus contains roughly **85 hours** of emotion-neutral recordings spoken by 218 native Chinese Mandarin speakers, 88035 utterances in total. Auxiliary attributes such as gender, age group, and native accent are explicitly marked and provided in the corpus, and transcripts at the Chinese character level and the pinyin level are provided along with the recordings. The word and tone transcription accuracy rate is above 98%, achieved through professional speech annotation and strict quality inspection of tone and prosody. (This database is free for academic research; commercial use requires permission.)
# [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
```
git clone https://github.com/SpeechColab/GigaSpeech.git
cd GigaSpeech
utils/gigaspeech_download.sh /disk1/audio_data/gigaspeech
toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data
cd ..
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
set -e
curdir=$PWD
test -d GigaSpeech || git clone https://github.com/SpeechColab/GigaSpeech.git
pushd GigaSpeech
source env_vars.sh
./utils/download_gigaspeech.sh ${curdir}/
#toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data
popd
dev-clean/
dev-other/
test-clean/
test-other/
train-clean-100/
train-clean-360/
train-other-500/
dev-clean
dev-other
test-clean
test-other
train-clean-100
train-clean-360
train-other-500
*.meta
manifest.*
......@@ -77,6 +77,10 @@ def create_manifest(data_dir, manifest_path):
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_text = 0.0
total_num = 0
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
......@@ -86,7 +90,9 @@ def create_manifest(data_dir, manifest_path):
for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split()
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.join(subfolder, segments[0] + '.flac')
audio_filepath = os.path.abspath(
os.path.join(subfolder, segments[0] + '.flac'))
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
json_lines.append(
......@@ -99,10 +105,27 @@ def create_manifest(data_dir, manifest_path):
'text':
text
}))
total_sec += duration
total_text += len(text)
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
subset = os.path.splitext(manifest_path)[1][1:]
manifest_dir = os.path.dirname(manifest_path)
data_dir_name = os.path.split(data_dir)[-1]
meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
with open(meta_path, 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
......
# [MagicData](http://www.openslr.org/68/)
MAGICDATA Mandarin Chinese Read Speech Corpus was developed by MAGIC DATA Technology Co., Ltd. and freely published for non-commercial use.
The contents and the corresponding descriptions of the corpus include:
* The corpus contains 755 hours of speech data, most of which was recorded on mobile devices.
* 1080 speakers from different accent areas in China were invited to participate in the recording.
* The sentence transcription accuracy is higher than 98%.
* Recordings were conducted in a quiet indoor environment.
* The database is divided into training, validation, and testing sets in a ratio of 51:1:2.
* Detailed information such as speech data coding and speaker information is preserved in the metadata file.
* The domain of the recorded texts is diversified, including interactive Q&A, music search, SNS messages, home command and control, etc.
* Segmented transcripts are also provided.
The corpus aims to support researchers in speech recognition, machine translation, speaker recognition, and other speech-related fields, and is therefore completely free for academic use.
......@@ -2,3 +2,4 @@ dev-clean/
manifest.dev-clean
manifest.train-clean
train-clean/
*.meta
......@@ -58,6 +58,10 @@ def create_manifest(data_dir, manifest_path):
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_text = 0.0
total_num = 0
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
......@@ -80,10 +84,27 @@ def create_manifest(data_dir, manifest_path):
'text':
text
}))
total_sec += duration
total_text += len(text)
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
subset = os.path.splitext(manifest_path)[1][1:]
manifest_dir = os.path.dirname(manifest_path)
data_dir_name = os.path.split(data_dir)[-1]
meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
with open(meta_path, 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
......
# multi-cn
This is a Chinese speech recognition recipe that trains on all Chinese corpora on OpenSLR, including:
* Aidatatang (140 hours)
* Aishell (151 hours)
* MagicData (712 hours)
* Primewords (99 hours)
* ST-CMDS (110 hours)
* THCHS-30 (26 hours)
* optional AISHELL2 (~1000 hours) if available
# [Primewords](http://www.openslr.org/47/)
This free Chinese Mandarin speech corpus is released by Shanghai Primewords Information Technology Co., Ltd.
The corpus was recorded on smart mobile phones by 296 native Chinese speakers. The transcription accuracy is higher than 98%, at a confidence level of 95%. It is free for academic use.
The mapping between transcripts and utterances is given in JSON format.
# [FreeST](http://www.openslr.org/38/)
*.tar.gz.*
manifest.*
*.md
EN-ZH/
train-split/
test-segment/
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Ted-En-Zh speech translation dataset
Create manifest files from the split dataset.
dev set: tst2010, test set: tst2015
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import soundfile
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--src_dir",
default="",
type=str,
help="Directory to kaldi splited data. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_types_infos = [
('train', 'train-split/train-segment', 'En-Zh/train.en-zh'),
('dev', 'test-segment/tst2010', 'En-Zh/tst2010.en-zh'),
('test', 'test-segment/tst2015', 'En-Zh/tst2015.en-zh')
]
for data_info in data_types_infos:
dtype, audio_relative_dir, text_relative_path = data_info
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
text_path = os.path.join(data_dir, text_relative_path)
audio_dir = os.path.join(data_dir, audio_relative_dir)
for line in codecs.open(text_path, 'r', 'utf-8', errors='ignore'):
line = line.strip()
if len(line) < 1:
continue
            audio_id, transcription, translation = line.split('\t')
utt = audio_id.split('.')[0]
audio_path = os.path.join(audio_dir, audio_id)
if os.path.exists(audio_path):
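                # heuristic filter: skip very small files (< ~30 kB), which
                # are likely to be empty or truncated recordings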
if os.path.getsize(audio_path) < 30000:
continue
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
json_lines.append(
json.dumps(
{
'utt': utt,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': " ".join(translation.split()),
                            'text1': " ".join(transcription.split())
},
ensure_ascii=False))
total_sec += duration
total_text += len(translation.split())
total_num += 1
if not total_num % 1000:
print(dtype, 'Processed:', total_num)
manifest_path = manifest_path_prefix + '.' + dtype + '.raw'
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def prepare_dataset(src_dir, manifest_path=None):
    """create manifest file."""
    if manifest_path and os.path.isdir(manifest_path):
        manifest_path = os.path.join(manifest_path, 'manifest')
if manifest_path:
create_manifest(src_dir, manifest_path)
def main():
if args.src_dir.startswith('~'):
args.src_dir = os.path.expanduser(args.src_dir)
prepare_dataset(src_dir=args.src_dir, manifest_path=args.manifest_prefix)
print("manifest prepare done!")
if __name__ == '__main__':
main()
*.tgz
manifest.*
data_thchs30
resource
test-noise
*.meta
# [THCHS30](http://www.openslr.org/18/)
This is the *data part* of the `THCHS30 2015` acoustic data
& scripts dataset.
The dataset is described in more detail in the paper ``THCHS-30 : A Free
Chinese Speech Corpus`` by Dong Wang, Xuewei Zhang.
A paper (if it can be called a paper) from 13 years earlier describes the predecessor database:
Dong Wang, Dalei Wu, Xiaoyan Zhu, ``TCMSD: A new Chinese Continuous Speech Database``,
International Conference on Chinese Computing (ICCC'01), 2001, Singapore.
The layout of this data pack is the following:
``data``
``*.wav``
audio data
``*.wav.trn``
transcriptions
``{train,dev,test}``
contain symlinks into the ``data`` directory for both audio and
transcription files. Contents of these directories define the
train/dev/test split of the data.
``{lm_word}``
``word.3gram.lm``
trigram LM based on word
``lexicon.txt``
lexicon based on word
``{lm_phone}``
``phone.3gram.lm``
trigram LM based on phone
``lexicon.txt``
lexicon based on phone
``README.TXT``
this file
Data statistics
===============
Statistics for the data are as follows:
=========== ========== ========== ===========
**dataset** **audio** **#sents** **#words**
=========== ========== ========== ===========
train 25 10,000 198,252
dev 2:14 893 17,743
test 6:15 2,495 49,085
=========== ========== ========== ===========
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare THCHS-30 mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from multiprocessing.pool import Pool
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/18'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/18'
DATA_URL = URL_ROOT + '/data_thchs30.tgz'
TEST_NOISE_URL = URL_ROOT + '/test-noise.tgz'
RESOURCE_URL = URL_ROOT + '/resource.tgz'
MD5_DATA = '2d2252bde5c8429929e1841d4cb95e90'
MD5_TEST_NOISE = '7e8a985fb965b84141b68c68556c2030'
MD5_RESOURCE = 'c0b2a565b4970a0c4fe89fefbf2d97e1'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/THCHS30",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def read_trn(filepath):
"""read trn file.
word text in first line.
syllable text in second line.
phoneme text in third line.
Args:
filepath (str): trn path.
Returns:
list(str): (word, syllable, phone)
"""
texts = []
with open(filepath, 'r') as f:
lines = f.read().strip().split('\n')
assert len(lines) == 3, lines
        # character text: remove whitespace
texts.append(''.join(lines[0].split()))
texts.extend(lines[1:])
return texts
def resolve_symlink(filepath):
"""resolve symlink which content is norm file.
Args:
filepath (str): norm file symlink.
"""
sym_path = Path(filepath)
relative_link = sym_path.read_text().strip()
relative = Path(relative_link)
relpath = sym_path.parent / relative
return relpath.resolve()
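# Example: in THCHS-30, train/A11_0.wav.trn is a plain-text file whose only
# content is the relative path "../data/A11_0.wav.trn"; resolve_symlink()
# returns the absolute path of that real transcription file under data/.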
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
file_path = os.path.join(subfolder, fname)
if file_path.endswith('.wav'):
audio_path = os.path.abspath(file_path)
text_path = resolve_symlink(audio_path + '.trn')
else:
continue
assert os.path.exists(audio_path) and os.path.exists(text_path)
audio_id = os.path.basename(audio_path)[:-4]
word_text, syllable_text, phone_text = read_trn(text_path)
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
                # do not dump time-alignment info
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': word_text, # charactor
'syllable': syllable_text,
'phone': phone_text,
},
ensure_ascii=False))
total_sec += duration
total_text += len(word_text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
"""Download, unpack and create manifest file."""
datadir = os.path.join(target_dir, subset)
if not os.path.exists(datadir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
if subset == 'data_thchs30':
create_manifest(datadir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
tasks = [
(DATA_URL, MD5_DATA, args.target_dir, args.manifest_prefix,
"data_thchs30"),
(TEST_NOISE_URL, MD5_TEST_NOISE, args.target_dir, args.manifest_prefix,
"test-noise"),
(RESOURCE_URL, MD5_RESOURCE, args.target_dir, args.manifest_prefix,
"resource"),
]
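    # download/unpack the three archives in parallel; the pool size only
    # needs to be >= the number of tasks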
with Pool(7) as pool:
pool.starmap(prepare_dataset, tasks)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
TIMIT.*
TIMIT
manifest.*
*.meta
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Librispeech ASR datasets.
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import re
import string
from pathlib import Path
import soundfile
from utils.utility import check_md5sum  # assumed to be provided alongside unzip; used below
from utils.utility import unzip
URL_ROOT = ""
MD5_DATA = "45c68037c7fdfe063a43c851f181fb2d"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default='~/.cache/paddle/dataset/speech/timit',
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
#: A string containing Chinese punctuation marks (non-stops).
non_stops = (
# Fullwidth ASCII variants
'\uFF02\uFF03\uFF04\uFF05\uFF06\uFF07\uFF08\uFF09\uFF0A\uFF0B\uFF0C\uFF0D'
'\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF20\uFF3B\uFF3C\uFF3D\uFF3E\uFF3F'
'\uFF40\uFF5B\uFF5C\uFF5D\uFF5E\uFF5F\uFF60'
# Halfwidth CJK punctuation
'\uFF62\uFF63\uFF64'
# CJK symbols and punctuation
'\u3000\u3001\u3003'
# CJK angle and corner brackets
'\u3008\u3009\u300A\u300B\u300C\u300D\u300E\u300F\u3010\u3011'
# CJK brackets and symbols/punctuation
'\u3014\u3015\u3016\u3017\u3018\u3019\u301A\u301B\u301C\u301D\u301E\u301F'
# Other CJK symbols
'\u3030'
# Special CJK indicators
'\u303E\u303F'
# Dashes
'\u2013\u2014'
# Quotation marks and apostrophe
'\u2018\u2019\u201B\u201C\u201D\u201E\u201F'
# General punctuation
'\u2026\u2027'
# Overscores and underscores
'\uFE4F'
# Small form variants
'\uFE51\uFE54'
# Latin punctuation
'\u00B7')
#: A string of Chinese stops.
stops = (
'\uFF01' # Fullwidth exclamation mark
'\uFF1F' # Fullwidth question mark
'\uFF61' # Halfwidth ideographic full stop
'\u3002' # Ideographic full stop
)
#: A string containing all Chinese punctuation.
punctuation = non_stops + stops
def tn(text):
# lower text
text = text.lower()
# remove punc
text = re.sub(f'[{punctuation}{string.punctuation}]', "", text)
return text
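# e.g. tn('She had your dark suit in greasy wash water all year.') returns
# 'she had your dark suit in greasy wash water all year'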
def read_txt(filepath: str) -> str:
with open(filepath, 'r') as f:
line = f.read().strip().split(maxsplit=2)[2]
return tn(line)
def read_align(filepath: str) -> str:
    """Read a word or phone alignment file.
    Each line: <start-sample> <end-sample> <token><newline>
    Args:
        filepath (str): path to the alignment file.
    Returns:
        str: tokens separated by <space>
    """
aligns = [] # (start, end, token)
with open(filepath, 'r') as f:
for line in f:
items = line.strip().split()
# for phone: (Note: beginning and ending silence regions are marked with h#)
if items[2].strip() == 'h#':
continue
aligns.append(items)
return ' '.join([item[2] for item in aligns])
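# e.g. for a .PHN file containing
#   0 3050 h#
#   3050 4559 sh
#   4559 5723 ix
# the 'h#' silence markers are dropped and the result is "sh ix".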
def create_manifest(data_dir, manifest_path_prefix):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
utts = set()
data_types = ['TRAIN', 'TEST']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = Path(os.path.join(data_dir, dtype))
for fname in sorted(audio_dir.rglob('*.WAV')):
audio_path = fname.resolve() # .WAV
audio_id = audio_path.stem
            # if the utt id was already seen, skip it
if audio_id in utts:
continue
utts.add(audio_id)
text_path = audio_path.with_suffix('.TXT')
phone_path = audio_path.with_suffix('.PHN')
word_path = audio_path.with_suffix('.WRD')
audio_data, samplerate = soundfile.read(
str(audio_path), dtype='int16')
duration = float(len(audio_data) / samplerate)
word_text = read_txt(text_path)
            phone_text = read_align(phone_path)
gender_spk = str(audio_path.parent.stem)
spk = gender_spk[1:]
gender = gender_spk[0]
utt_id = '_'.join([spk, gender, audio_id])
            # do not dump time-alignment info
json_lines.append(
json.dumps(
{
'utt': utt_id,
'feat': str(audio_path),
'feat_shape': (duration, ), # second
'text': word_text, # word
'phone': phone_text,
'spk': spk,
'gender': gender,
},
ensure_ascii=False))
total_sec += duration
total_text += len(word_text.split())
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype.lower()
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype.lower()) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
"""
filepath = os.path.join(target_dir, "TIMIT.zip")
if not os.path.exists(filepath):
print(f"Please download TIMIT.zip into {target_dir}.")
raise FileNotFoundError
if not os.path.exists(os.path.join(target_dir, "TIMIT")):
# check md5sum
assert check_md5sum(filepath, md5sum)
# unpack
unzip(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(os.path.join(target_dir, "TIMIT"), manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(URL_ROOT, MD5_DATA, args.target_dir, args.manifest_prefix)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
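# Example invocation (script name illustrative). TIMIT is licensed, so
# TIMIT.zip must be placed in --target_dir manually before running:
#   python3 timit.py --target_dir ~/.cache/paddle/dataset/speech/timit --manifest_prefix data/manifest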
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare TIMIT dataset (Standard split from Kaldi)
Create manifest files from the split dataset.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import soundfile
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--src_dir",
default="",
type=str,
help="Directory to kaldi splited data. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
phn_path = os.path.join(data_dir, dtype + '.text')
phn_dict = {}
for line in codecs.open(phn_path, 'r', 'utf-8'):
line = line.strip()
if line == '':
continue
audio_id, text = line.split(' ', 1)
phn_dict[audio_id] = text
audio_dir = os.path.join(data_dir, dtype + '_sph.scp')
for line in codecs.open(audio_dir, 'r', 'utf-8'):
audio_id, audio_path = line.strip().split()
# if no transcription for audio then raise error
assert audio_id in phn_dict
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = phn_dict[audio_id]
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': text
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype + '.raw'
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def prepare_dataset(src_dir, manifest_path=None):
    """create manifest file."""
    if manifest_path and os.path.isdir(manifest_path):
        manifest_path = os.path.join(manifest_path, 'manifest')
if manifest_path:
create_manifest(src_dir, manifest_path)
def main():
if args.src_dir.startswith('~'):
args.src_dir = os.path.expanduser(args.src_dir)
prepare_dataset(src_dir=args.src_dir, manifest_path=args.manifest_prefix)
print("manifest prepare done!")
if __name__ == '__main__':
main()
# ASR
* s0 is for deepspeech2
* s0 is for deepspeech2 offline
* s1 is for transformer/conformer/U2
* s2 is for transformer/conformer/U2 w/ kaldi feat
s2 requires Kaldi to be installed.
......@@ -2,8 +2,9 @@
## Deepspeech2
| Model | release | Config | Test set | Loss | WER |
| --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
| DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
| DeepSpeech2 | 1.8.5 | - | test-clean | - | 0.074939 |
| Model | Params | release | Config | Test set | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 |
| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
| DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 |
......@@ -15,5 +15,20 @@
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
},
"prob": 1.0
}
]
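For reference, a minimal NumPy sketch of the frequency/time masking that the `specaug` entry above configures, assuming the standard SpecAugment semantics (Park et al., 2019). Time warping (`W`), the mask-width probability `p`, and the adaptive parameters are omitted, and all names here are illustrative rather than the repo's API:
```python
import numpy as np

def spec_augment_sketch(spec, F=10, T=50, n_freq_masks=2, n_time_masks=2, seed=None):
    """Zero out random frequency bands and time spans of a spectrogram."""
    rng = np.random.default_rng(seed)
    aug = spec.copy()  # spec shape: (freq_bins, frames)
    n_bins, n_frames = aug.shape
    for _ in range(n_freq_masks):  # each mask covers up to F consecutive bins
        f = int(rng.integers(0, min(F, n_bins) + 1))
        f0 = int(rng.integers(0, n_bins - f + 1))
        aug[f0:f0 + f, :] = 0.0
    for _ in range(n_time_masks):  # each mask covers up to T consecutive frames
        t = int(rng.integers(0, min(T, n_frames) + 1))
        t0 = int(rng.integers(0, n_frames - t + 1))
        aug[:, t0:t0 + t] = 0.0
    return aug

# e.g. spec_augment_sketch(np.random.rand(80, 300))  # 80 fbank bins, 300 frames
```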
......@@ -3,16 +3,21 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev-clean
test_manifest: data/manifest.test-clean
mean_std_filepath: data/mean_std.json
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 20
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 20
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
target_sample_rate: 16000
max_freq: None
......@@ -27,7 +32,7 @@ data:
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
num_workers: 2
model:
num_conv_layers: 2
......@@ -43,6 +48,9 @@ training:
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
......
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev-clean
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 20
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 20.0
delta_delta: False
dither: 1.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 2048
rnn_direction: forward
num_fc_layers: 2
fc_layers_size_list: 512, 256
use_gru: False
training:
n_epoch: 50
lr: 1e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 1.9
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
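For reference, `alpha` and `beta` above are the usual shallow-fusion weights for the external KenLM language model during CTC beam search: `alpha` scales the LM log-probability and `beta` adds a per-word insertion bonus. A hypothetical sketch of the conventional combined score (not the repo's actual decoder code):
```python
def beam_score(log_p_ctc: float, log_p_lm: float, word_count: int,
               alpha: float = 1.9, beta: float = 0.3) -> float:
    # total = acoustic (CTC) log-prob + alpha * LM log-prob + beta * word bonus
    return log_p_ctc + alpha * log_p_lm + beta * word_count
```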
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
#! /usr/bin/env bash
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
if [ $# != 4 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
exit -1
fi
......@@ -11,9 +11,10 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
model_type=$4
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in export!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
......@@ -9,11 +9,12 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
......@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
......@@ -10,9 +10,10 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
model_type=$3
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
......@@ -23,7 +24,8 @@ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 1 ];then
echo "usage: tune ckpt_path"
......
export MAIN_ROOT=${PWD}/../../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
......@@ -6,6 +6,7 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=30
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
......@@ -19,20 +20,20 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
# LibriSpeech
## Data
| Data Subset | Duration in Seconds |
| --- | --- |
| data/manifest.train | 0.83s ~ 29.735s |
| data/manifest.dev | 1.065s ~ 35.155s |
| data/manifest.test-clean | 1.285s ~ 34.955s |
## Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | 6.35 | 0.030162 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 6.35 | 0.037910 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 6.35 | 0.037761 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 6.35 | 0.032115 |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | 6.35 | 0.057117 |
## Chunk Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | 7.11 | 0.063193 |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | 7.11 | 0.082394 |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | 7.11 | 0.082156 |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | 7.11 | 0.071000 |
| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| conformer | conf/conformer.yaml | spec_aug + shift | test-all | attention | test-all 6.35 | 0.057117 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.35 | 0.030162 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | test-all 6.35 | 0.037910 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | test-all 6.35 | 0.037761 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | test-all 6.35 | 0.032115 |
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | 6.98 | 0.036 |
| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| transformer | conf/transformer.yaml | spec_aug + shift | test-all | attention | test-all 6.98 | 0.066500 |
| transformer | conf/transformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.98 | 0.036 |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | 6.98 | 0.066500 |
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -79,8 +81,8 @@ model:
training:
n_epoch: 120
accum_grad: 1
n_epoch: 240
accum_grad: 8
global_grad_clip: 5.0
optim: adam
optim_conf:
......@@ -91,6 +93,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
......
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -84,6 +86,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
......@@ -103,6 +108,6 @@ decoding:
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
simulate_streaming: true # simulate streaming inference. Defaults to False.
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
min_input_len: 0.5 # seconds
max_input_len: 20.0 # seconds
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 32
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -76,7 +78,7 @@ model:
training:
n_epoch: 120
accum_grad: 8
accum_grad: 4
global_grad_clip: 3.0
optim: adam
optim_conf:
......@@ -87,6 +89,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
......
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -82,6 +84,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
#! /usr/bin/env bash
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
......@@ -13,7 +13,7 @@ ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
......@@ -9,12 +9,20 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
echo "chunk mode ${chunk_mode}"
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
......@@ -23,7 +31,12 @@ ckpt_prefix=$2
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
batch_size=64
if [ ${chunk_mode} == true ];then
# streaming decoding only supports batch_size=1
batch_size=1
else
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
......@@ -12,7 +12,7 @@ config_path=$1
ckpt_name=$2
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
......
export MAIN_ROOT=${PWD}/../../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export PATH=${MAIN_ROOT}:${PWD}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
......
......@@ -19,20 +19,25 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
../../../utils
\ No newline at end of file
# LibriSpeech
## Data
| Data Subset | Duration in Seconds |
| --- | --- |
| data/manifest.train | 0.83s ~ 29.735s |
| data/manifest.dev | 1.065s ~ 35.155s |
| data/manifest.test-clean | 1.285s ~ 34.955s |
## Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | - | - |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
## Chunk Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | | |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | | |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | | - |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | | - |
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 0.0
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
},
"prob": 1.0
}
]
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: True
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 240
accum_grad: 8
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: true # simulate streaming inference. Defaults to False.
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: transformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6 or conv2d8
normalize_before: true
use_dynamic_chunk: true
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 1
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: true # simulate streaming inference. Defaults to False.
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.5 # seconds
max_input_len: 20.0 # seconds
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6 or conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 8
global_grad_clip: 3.0
optim: adam
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: transformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6 or conv2d8
normalize_before: true
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 2
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# alignment results are dumped to `result_file`
# .tier and .TextGrid files are dumped to the directory of `result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/ctc.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#!/bin/bash
stage=-1
stop_stage=100
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type "spm" \
--spm_vocab_size=${nbpe} \
--spm_mode ${bpemode} \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
--manifest_paths="data/manifest.train.raw"
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=-1 \
--specgram_type="fbank" \
--feat_dim=80 \
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "spm" \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
if [ $? -ne 0 ]; then
echo "Failed in export!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
echo "chunk mode ${chunk_mode}"
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# streaming decoding only supports batch_size=1
batch_size=1
else
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${PWD}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/transformer.yaml
avg_num=30
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
../../../utils/
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH}
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# TED En -> Zh
* t0 for u2 speech translation
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train.tiny
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.05 # second
max_input_len: 30.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.01
max_output_input_ratio: 20.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: data/bpe_unigram_8000
mean_std_filepath: ""
# augmentation_config: conf/augmentation.json
batch_size: 10
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: transformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6 or conv2d8
normalize_before: true
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
asr_weight: 0.0
ctc_weight: 0.0
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 2
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 5
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 5
error_rate_type: char-bleu
decoding_method: fullsentence # 'fullsentence', 'simultaneous'
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.05 # second
max_input_len: 30.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.01
max_output_input_ratio: 20.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: data/bpe_unigram_8000
mean_std_filepath: ""
# augmentation_config: conf/augmentation.json
batch_size: 10
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: transformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6 or conv2d8
normalize_before: true
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
asr_weight: 0.5
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 2
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 2.5
weight_decay: 1e-06
scheduler: noam
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 5
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 5
error_rate_type: char-bleu
decoding_method: fullsentence # 'fullsentence', 'simultaneous'
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
#!/bin/bash
stage=-1
stop_stage=100
# bpemode (unigram or bpe)
nbpe=8000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
DATA_DIR=
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
if [ ! -d ${DATA_DIR} ]; then
echo "Error: dataset is not available. Please download and unzip the dataset."
echo "Download Link: https://pan.baidu.com/s/18L-59wgeS96WkObISrytQQ Passwd: bva0"
echo "The tree of the directory should be:"
echo "."
echo "|-- En-Zh"
echo "|-- test-segment"
echo " |-- tst2010"
echo " |-- ..."
echo "|-- train-split"
echo " |-- train-segment"
echo "|-- README.md"
exit 1
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# generate manifests
python3 ${TARGET_DIR}/ted_en_zh/ted_en_zh.py \
--manifest_prefix="data/manifest" \
--src_dir="${DATA_DIR}"
echo "Complete raw data pre-process."
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type "spm" \
--spm_vocab_size=${nbpe} \
--spm_mode ${bpemode} \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
--text_keys 'text' 'text1' \
--manifest_paths="data/manifest.train.raw"
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=-1 \
--specgram_type="fbank" \
--feat_dim=80 \
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_triplet_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "spm" \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "Ted En-Zh Data preparation done."
exit 0
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
for type in fullsentence; do
echo "decoding ${type}"
batch_size=32
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2_st
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/transformer_joint_noam.yaml
avg_num=5
data_path=./TED-En-Zh # path to unzipped data
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh --DATA_DIR ${data_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
../../utils/avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
# Regular expression based text normalization for Chinese
For simplicity and ease of implementation, text normalization is done mostly with rules and dictionaries. Here are some example inputs; a minimal rule sketch follows the list.
今天的最低气温达到-10°C.
只要有33/4的人同意,就可以通过决议。
1945年5月2日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。
4月16日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。
如果剩下的30.6%是过去,那么还有69.4%.
事情发生在2020/03/31的上午8:00.
警方正在找一支.22口径的手枪。
欢迎致电中国联通,北京2022年冬奥会官方合作伙伴为您服务
充值缴费请按1,查询话费及余量请按2,跳过本次提醒请按井号键。
快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按9,查询话费、套餐余量、积分及活动返款请按1,手机上网流量开通及取消请按2,查询本机号码及本号所使用套餐请按4,密码修改及重置请按5,紧急开机请按6,挂失请按7,查询充值记录请按8,其它自助服务及人工服务请按0
智能客服助理快速查话费、查流量请按9,了解北京联通业务请按1,宽带IPTV新装、查询请按2,障碍报修请按3,充值缴费请按4,投诉建议请按5,政企业务请按7,人工服务请按0,for english severice press star key
您的帐户当前可用余额为63.89元,本月消费为2.17元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。
您的帐户当前可用余额为负15.5元,本月消费为59.6元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。
尊敬的客户,您目前的话费余额为负14.60元,已低于10元,为保证您的通信畅通,请及时缴纳费用。
您的流量已用完,为避免您产生额外费用,建议您根据需求开通一个流量包以作补充。
您可以直接说,查询话费及余量、开通流量包、缴费,您也可以说出其它需求,请问有什么可以帮您?
您的账户当前可用余额为负36.00元,本月消费36.00元。
请问你是电话13985608526的机主吗?
如您对处理结果不满意,可拨打中国联通集团投诉电话10015进行投诉,按本地通话费收费,返回自助服务请按井号键
“26314”号VIP客服代表为您服务。
尊敬的5G用户,欢迎您致电中国联通
首先是应用了M1芯片的iPad Pro,新款的iPad Pro支持5G,这也是苹果的第二款5G产品线。
除此之外,摄像头方面再次升级,增加了前摄全新超广角摄像头,支持人物居中功能,搭配超广角可实现视频中始终让人物居中效果。
屏幕方面,iPad Pro 12.9版本支持XDR体验的Mini-LEDS显示屏,支持HDR10、杜比视界,还支持杜比全景声。
iPad Pro的秒控键盘这次也推出白色版本。
售价方面,11英寸版本售价799美元起,12.9英寸售价1099美元起。
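To give a flavor of such a rule, here is a minimal, self-contained sketch. It is an illustration only, not the project's actual rules in `text_processing.normalization`; the function names and the 0-99 number reader are assumptions made for brevity. It expands one pattern from the examples above, a temperature such as 「-10°C」, into spoken Chinese:
```python
import re

DIGITS = "零一二三四五六七八九"

def verbalize_number(s: str) -> str:
    """Tiny number reader covering 0-99 only, enough for this illustration."""
    n = int(s)
    if n < 10:
        return DIGITS[n]
    tens, ones = divmod(n, 10)
    return ("" if tens == 1 else DIGITS[tens]) + "十" + (DIGITS[ones] if ones else "")

def normalize_temperature(text: str) -> str:
    """Rewrite '-10°C'-style patterns as spoken Chinese, e.g. '零下十摄氏度'."""
    def repl(m):
        sign = "零下" if m.group(1) else ""  # leading '-' reads as 'below zero'
        return sign + verbalize_number(m.group(2)) + "摄氏度"
    return re.sub(r"(-)?(\d+)°C", repl, text)

print(normalize_temperature("今天的最低气温达到-10°C."))  # 今天的最低气温达到零下十摄氏度.
```
A full normalizer chains many rules of this shape (dates, fractions, phone numbers, currency amounts), each essentially a regex paired with a verbalizer, plus dictionaries for units.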
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from text_processing import normalization
parser = argparse.ArgumentParser(
description="Normalize text in Chinese with some rules.")
parser.add_argument("input", type=str, help="the input sentences")
parser.add_argument("output", type=str, help="path to save the output file.")
args = parser.parse_args()
# Open files as UTF-8 explicitly; path.sh sets LC_ALL=C, so relying on the
# locale default encoding would raise UnicodeDecodeError on Chinese input.
with open(args.input, 'rt', encoding='utf-8') as fin:
with open(args.output, 'wt', encoding='utf-8') as fout:
for sent in fin:
sent = normalization.normalize_sentence(sent.strip())
fout.write(sent)
fout.write('\n')
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${MAIN_ROOT}/third_party:${PYTHONPATH}
#!/usr/bin/env bash
source path.sh
stage=-1
stop_stage=100
exp_dir=exp
data_dir=data
filename="sentences.txt"
source ${MAIN_ROOT}/utils/parse_options.sh || exit -1
mkdir -p ${exp_dir}
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "stage 1: Processing "
python3 local/test_normalization.py ${data_dir}/${filename} ${exp_dir}/normalized.txt
if [ -f "${exp_dir}/normalized.txt" ]; then
echo "Normalized text save at ${exp_dir}/normalized.txt"
fi
# TODO(chenfeiyu): compute edit distance against ground-truth
fi
echo "done"
exit 0
# thchs30
* a0 for mfa alignment
# Forced alignment on the THCHS-30 dataset
-----
This experiment force-aligns the THCHS-30 Mandarin dataset with [Montreal-Forced-Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/index.html).
The THCHS-30 text annotations come at three levels:
1. Character level (word): the dataset separates words with spaces; when using it, we instead put spaces between individual characters.
2. Syllable level (syllable): one pinyin syllable of Mandarin.
3. Phone level (phone): a pinyin syllable is made up of several phones, and Mandarin initials and finals can be regarded as phones. Phone inventories differ across datasets; the THCHS-30 standard differs slightly from that of the Biaobei BZNSYP dataset.
An example of the text for utterance A11_0:
```
绿 是 阳春 烟 景 大块 文章 的 底色 四月 的 林 峦 更是 绿 得 鲜活 秀媚 诗意 盎然
lv4 shi4 yang2 chun1 yan1 jing3 da4 kuai4 wen2 zhang1 de5 di3 se4 si4 yue4 de5 lin2 luan2 geng4 shi4 lv4 de5 xian1 huo2 xiu4 mei4 shi1 yi4 ang4 ran2
l v4 sh ix4 ii iang2 ch un1 ii ian1 j ing3 d a4 k uai4 uu un2 zh ang1 d e5 d i3 s e4 s iy4 vv ve4 d e5 l in2 l uan2 g eng4 sh ix4 l v4 d e5 x ian1 h uo2 x iu4 m ei4 sh ix1 ii i4 aa ang4 r an2
```
## Running the experiment
---
In `tools/` under the project root, run
```
make
```
to download the MFA executable package (this also downloads the other tools the project needs).
Then run:
```
cd a0
./run.sh
```
The script automatically downloads the THCHS-30 dataset, converts it into the file format MFA expects, and starts training. You can edit the `LEXICON_NAME` parameter in `run.sh` to choose the level at which to force-align (word, syllable, or phone).
## Dictionaries used by MFA
---
For the dictionary format, see the official MFA docs: [Dictionary format](https://montreal-forced-aligner.readthedocs.io/en/latest/dictionary.html)
`phone.lexicon` is used directly from `THCHS-30/data_thchs30/lm_phone/lexicon.txt`.
`word.lexicon` handles Chinese heteronyms by using a **dictionary with pronunciation probabilities**; the generation rules are in `local/gen_word2phone.py`, and a rough sketch of the idea appears below.
`syllable.lexicon` is taken from [DNSun/thchs30-pinyin2tone](https://github.com/DNSun/thchs30-pinyin2tone)
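As a rough illustration only (not the actual code in `local/gen_word2phone.py`), the sketch below shows one way such a probability-weighted word lexicon could be built: count, over paired character/pinyin transcripts like the A11_0 example above, how often each word is read as each syllable sequence, then expand syllables into phones with a syllable lexicon of the kind provided in `data/dict`. All function names and the exact I/O format are assumptions.
```python
from collections import Counter, defaultdict

def load_syllable2phone(path):
    """Parse a syllable lexicon such as 'LV4 l v4' into {'LV4': ['l', 'v4']}."""
    table = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            syllable, *phones = line.split()
            table[syllable] = phones
    return table

def count_pronunciations(word_lines, syllable_lines):
    """Count how often each space-separated word is read as each syllable sequence."""
    counts = defaultdict(Counter)
    for words, syllables in zip(word_lines, syllable_lines):
        words, syllables = words.split(), syllables.split()
        i = 0  # one syllable per character, so walk both sequences in lockstep
        for word in words:
            counts[word][' '.join(syllables[i:i + len(word)])] += 1
            i += len(word)
    return counts

def emit_lexicon(counts, syl2phone, out):
    """Write 'word probability phone ...' entries (MFA-style pronunciation probabilities)."""
    for word, prons in counts.items():
        total = sum(prons.values())
        for pron, n in prons.items():
            phones = [p for s in pron.split() for p in syl2phone[s.upper()]]
            out.write(f"{word}\t{n / total:.3f}\t{' '.join(phones)}\n")
```
A real implementation also has to handle erhua, punctuation, and out-of-lexicon syllables, which this sketch ignores.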
## Alignment results
---
We provide MFA alignment results, trained models, and dictionaries for all three levels (`syllable.lexicon` is in `data/dict`; `phone.lexicon` and `word.lexicon` are copied or generated from the raw dataset automatically once the data preprocessing code has run).
**Phone level:** [phone.lexicon](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/phone/phone.lexicon), [alignment results](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/phone/thchs30_alignment.tar.gz), [model](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/phone/thchs30_model.zip)
**Syllable level:** [syllable.lexicon](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/syllable/syllable.lexicon), [alignment results](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/syllable/thchs30_alignment.tar.gz), [model](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/syllable/thchs30_model.zip)
**Word level:** [word.lexicon](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/word/word.lexicon), [alignment results](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/word/thchs30_alignment.tar.gz), [model](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/word/thchs30_model.zip)
You can then follow the official MFA docs ([Align using pretrained models](https://montreal-forced-aligner.readthedocs.io/en/stable/aligning.html#align-using-pretrained-models)) to force-align your own dataset with the models provided above. Note that you must use the lexicon file matching the model, and when the text consists of Chinese characters you need to separate the individual **characters** (not words) with spaces, as in the one-line sketch below.
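For example, spacing out the characters of a transcript is a one-liner in Python (a sketch, not project code):
```python
text = "绿是阳春烟景"
# Insert a space between every character so the word-level lexicon can look each one up.
spaced = " ".join(text)  # -> "绿 是 阳 春 烟 景"
```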
A0 aa a0
A1 aa a1
A2 aa a2
A3 aa a3
A4 aa a4
AI0 aa ai0
AI1 aa ai1
AI2 aa ai2
AI3 aa ai3
AI4 aa ai4
AN0 aa an0
AN1 aa an1
AN2 aa an2
AN3 aa an3
AN4 aa an4
ANG0 aa ang0
ANG1 aa ang1
ANG2 aa ang2
ANG3 aa ang3
ANG4 aa ang4
AO0 aa ao0
AO1 aa ao1
AO2 aa ao2
AO3 aa ao3
AO4 aa ao4
BA0 b a0
BA1 b a1
BA2 b a2
BA3 b a3
BA4 b a4
BAI0 b ai0
BAI1 b ai1
BAI2 b ai2
BAI3 b ai3
BAI4 b ai4
BAN0 b an0
BAN1 b an1
BAN2 b an2
BAN3 b an3
BAN4 b an4
BANG0 b ang0
BANG1 b ang1
BANG2 b ang2
BANG3 b ang3
BANG4 b ang4
BAO0 b ao0
BAO1 b ao1
BAO2 b ao2
BAO3 b ao3
BAO4 b ao4
BEI0 b ei0
BEI1 b ei1
BEI2 b ei2
BEI3 b ei3
BEI4 b ei4
BEN0 b en0
BEN1 b en1
BEN2 b en2
BEN3 b en3
BEN4 b en4
BENG0 b eng0
BENG1 b eng1
BENG2 b eng2
BENG3 b eng3
BENG4 b eng4
BI0 b i0
BI1 b i1
BI2 b i2
BI3 b i3
BI4 b i4
BIAN0 b ian0
BIAN1 b ian1
BIAN2 b ian2
BIAN3 b ian3
BIAN4 b ian4
BIAO0 b iao0
BIAO1 b iao1
BIAO2 b iao2
BIAO3 b iao3
BIAO4 b iao4
BIE0 b ie0
BIE1 b ie1
BIE2 b ie2
BIE3 b ie3
BIE4 b ie4
BIN0 b in0
BIN1 b in1
BIN2 b in2
BIN3 b in3
BIN4 b in4
BING0 b ing0
BING1 b ing1
BING2 b ing2
BING3 b ing3
BING4 b ing4
BO0 b o0
BO1 b o1
BO2 b o2
BO3 b o3
BO4 b o4
BU0 b u0
BU1 b u1
BU2 b u2
BU3 b u3
BU4 b u4
CA0 c a0
CA1 c a1
CA2 c a2
CA3 c a3
CA4 c a4
CAI0 c ai0
CAI1 c ai1
CAI2 c ai2
CAI3 c ai3
CAI4 c ai4
CAN0 c an0
CAN1 c an1
CAN2 c an2
CAN3 c an3
CAN4 c an4
CANG0 c ang0
CANG1 c ang1
CANG2 c ang2
CANG3 c ang3
CANG4 c ang4
CAO0 c ao0
CAO1 c ao1
CAO2 c ao2
CAO3 c ao3
CAO4 c ao4
CE0 c e0
CE1 c e1
CE2 c e2
CE3 c e3
CE4 c e4
CEN0 c en0
CEN1 c en1
CEN2 c en2
CEN3 c en3
CEN4 c en4
CENG0 c eng0
CENG1 c eng1
CENG2 c eng2
CENG3 c eng3
CENG4 c eng4
CHA0 ch a0
CHA1 ch a1
CHA2 ch a2
CHA3 ch a3
CHA4 ch a4
CHAI0 ch ai0
CHAI1 ch ai1
CHAI2 ch ai2
CHAI3 ch ai3
CHAI4 ch ai4
CHAN0 ch an0
CHAN1 ch an1
CHAN2 ch an2
CHAN3 ch an3
CHAN4 ch an4
CHANG0 ch ang0
CHANG1 ch ang1
CHANG2 ch ang2
CHANG3 ch ang3
CHANG4 ch ang4
CHAO0 ch ao0
CHAO1 ch ao1
CHAO2 ch ao2
CHAO3 ch ao3
CHAO4 ch ao4
CHE0 ch e0
CHE1 ch e1
CHE2 ch e2
CHE3 ch e3
CHE4 ch e4
CHEN0 ch en0
CHEN1 ch en1
CHEN2 ch en2
CHEN3 ch en3
CHEN4 ch en4
CHENG0 ch eng0
CHENG1 ch eng1
CHENG2 ch eng2
CHENG3 ch eng3
CHENG4 ch eng4
CHI0 ch ix0
CHI1 ch ix1
CHI2 ch ix2
CHI3 ch ix3
CHI4 ch ix4
CHONG0 ch ong0
CHONG1 ch ong1
CHONG2 ch ong2
CHONG3 ch ong3
CHONG4 ch ong4
CHOU0 ch ou0
CHOU1 ch ou1
CHOU2 ch ou2
CHOU3 ch ou3
CHOU4 ch ou4
CHU0 ch u0
CHU1 ch u1
CHU2 ch u2
CHU3 ch u3
CHU4 ch u4
CHUAI0 ch uai0
CHUAI1 ch uai1
CHUAI2 ch uai2
CHUAI3 ch uai3
CHUAI4 ch uai4
CHUAN0 ch uan0
CHUAN1 ch uan1
CHUAN2 ch uan2
CHUAN3 ch uan3
CHUAN4 ch uan4
CHUANG0 ch uang0
CHUANG1 ch uang1
CHUANG2 ch uang2
CHUANG3 ch uang3
CHUANG4 ch uang4
CHUI0 ch ui0
CHUI1 ch ui1
CHUI2 ch ui2
CHUI3 ch ui3
CHUI4 ch ui4
CHUN0 ch un0
CHUN1 ch un1
CHUN2 ch un2
CHUN3 ch un3
CHUN4 ch un4
CHUO0 ch uo0
CHUO1 ch uo1
CHUO2 ch uo2
CHUO3 ch uo3
CHUO4 ch uo4
CI0 c iy0
CI1 c iy1
CI2 c iy2
CI3 c iy3
CI4 c iy4
CONG0 c ong0
CONG1 c ong1
CONG2 c ong2
CONG3 c ong3
CONG4 c ong4
COU0 c ou0
COU1 c ou1
COU2 c ou2
COU3 c ou3
COU4 c ou4
CU0 c u0
CU1 c u1
CU2 c u2
CU3 c u3
CU4 c u4
CUAN0 c uan0
CUAN1 c uan1
CUAN2 c uan2
CUAN3 c uan3
CUAN4 c uan4
CUI0 c ui0
CUI1 c ui1
CUI2 c ui2
CUI3 c ui3
CUI4 c ui4
CUN0 c un0
CUN1 c un1
CUN2 c un2
CUN3 c un3
CUN4 c un4
CUO0 c uo0
CUO1 c uo1
CUO2 c uo2
CUO3 c uo3
CUO4 c uo4
DA0 d a0
DA1 d a1
DA2 d a2
DA3 d a3
DA4 d a4
DAI0 d ai0
DAI1 d ai1
DAI2 d ai2
DAI3 d ai3
DAI4 d ai4
DAN0 d an0
DAN1 d an1
DAN2 d an2
DAN3 d an3
DAN4 d an4
DANG0 d ang0
DANG1 d ang1
DANG2 d ang2
DANG3 d ang3
DANG4 d ang4
DAO0 d ao0
DAO1 d ao1
DAO2 d ao2
DAO3 d ao3
DAO4 d ao4
DE0 d e0
DE1 d e1
DE2 d e2
DE3 d e3
DE4 d e4
DEI0 d ei0
DEI1 d ei1
DEI2 d ei2
DEI3 d ei3
DEI4 d ei4
DEN0 d en0
DEN1 d en1
DEN2 d en2
DEN3 d en3
DEN4 d en4
DENG0 d eng0
DENG1 d eng1
DENG2 d eng2
DENG3 d eng3
DENG4 d eng4
DI0 d i0
DI1 d i1
DI2 d i2
DI3 d i3
DI4 d i4
DIA0 d ia0
DIA1 d ia1
DIA2 d ia2
DIA3 d ia3
DIA4 d ia4
DIAN0 d ian0
DIAN1 d ian1
DIAN2 d ian2
DIAN3 d ian3
DIAN4 d ian4
DIAO0 d iao0
DIAO1 d iao1
DIAO2 d iao2
DIAO3 d iao3
DIAO4 d iao4
DIE0 d ie0
DIE1 d ie1
DIE2 d ie2
DIE3 d ie3
DIE4 d ie4
DING0 d ing0
DING1 d ing1
DING2 d ing2
DING3 d ing3
DING4 d ing4
DIU0 d iu0
DIU1 d iu1
DIU2 d iu2
DIU3 d iu3
DIU4 d iu4
DONG0 d ong0
DONG1 d ong1
DONG2 d ong2
DONG3 d ong3
DONG4 d ong4
DOU0 d ou0
DOU1 d ou1
DOU2 d ou2
DOU3 d ou3
DOU4 d ou4
DU0 d u0
DU1 d u1
DU2 d u2
DU3 d u3
DU4 d u4
DUAN0 d uan0
DUAN1 d uan1
DUAN2 d uan2
DUAN3 d uan3
DUAN4 d uan4
DUI0 d ui0
DUI1 d ui1
DUI2 d ui2
DUI3 d ui3
DUI4 d ui4
DUN0 d un0
DUN1 d un1
DUN2 d un2
DUN3 d un3
DUN4 d un4
DUO0 d uo0
DUO1 d uo1
DUO2 d uo2
DUO3 d uo3
DUO4 d uo4
E0 ee e0
E1 ee e1
E2 ee e2
E3 ee e3
E4 ee e4
EN0 ee en0
EN1 ee en1
EN2 ee en2
EN3 ee en3
EN4 ee en4
ER0 ee er0
ER1 ee er1
ER2 ee er2
ER3 ee er3
ER4 ee er4
FA0 f a0
FA1 f a1
FA2 f a2
FA3 f a3
FA4 f a4
FAN0 f an0
FAN1 f an1
FAN2 f an2
FAN3 f an3
FAN4 f an4
FANG0 f ang0
FANG1 f ang1
FANG2 f ang2
FANG3 f ang3
FANG4 f ang4
FEI0 f ei0
FEI1 f ei1
FEI2 f ei2
FEI3 f ei3
FEI4 f ei4
FEN0 f en0
FEN1 f en1
FEN2 f en2
FEN3 f en3
FEN4 f en4
FENG0 f eng0
FENG1 f eng1
FENG2 f eng2
FENG3 f eng3
FENG4 f eng4
FO0 f o0
FO1 f o1
FO2 f o2
FO3 f o3
FO4 f o4
FOU0 f ou0
FOU1 f ou1
FOU2 f ou2
FOU3 f ou3
FOU4 f ou4
FU0 f u0
FU1 f u1
FU2 f u2
FU3 f u3
FU4 f u4
GA0 g a0
GA1 g a1
GA2 g a2
GA3 g a3
GA4 g a4
GAI0 g ai0
GAI1 g ai1
GAI2 g ai2
GAI3 g ai3
GAI4 g ai4
GAN0 g an0
GAN1 g an1
GAN2 g an2
GAN3 g an3
GAN4 g an4
GANG0 g ang0
GANG1 g ang1
GANG2 g ang2
GANG3 g ang3
GANG4 g ang4
GAO0 g ao0
GAO1 g ao1
GAO2 g ao2
GAO3 g ao3
GAO4 g ao4
GE0 g e0
GE1 g e1
GE2 g e2
GE3 g e3
GE4 g e4
GEI0 g ei0
GEI1 g ei1
GEI2 g ei2
GEI3 g ei3
GEI4 g ei4
GEN0 g en0
GEN1 g en1
GEN2 g en2
GEN3 g en3
GEN4 g en4
GENG0 g eng0
GENG1 g eng1
GENG2 g eng2
GENG3 g eng3
GENG4 g eng4
GONG0 g ong0
GONG1 g ong1
GONG2 g ong2
GONG3 g ong3
GONG4 g ong4
GOU0 g ou0
GOU1 g ou1
GOU2 g ou2
GOU3 g ou3
GOU4 g ou4
GU0 g u0
GU1 g u1
GU2 g u2
GU3 g u3
GU4 g u4
GUA0 g ua0
GUA1 g ua1
GUA2 g ua2
GUA3 g ua3
GUA4 g ua4
GUAI0 g uai0
GUAI1 g uai1
GUAI2 g uai2
GUAI3 g uai3
GUAI4 g uai4
GUAN0 g uan0
GUAN1 g uan1
GUAN2 g uan2
GUAN3 g uan3
GUAN4 g uan4
GUANG0 g uang0
GUANG1 g uang1
GUANG2 g uang2
GUANG3 g uang3
GUANG4 g uang4
GUI0 g ui0
GUI1 g ui1
GUI2 g ui2
GUI3 g ui3
GUI4 g ui4
GUN0 g un0
GUN1 g un1
GUN2 g un2
GUN3 g un3
GUN4 g un4
GUO0 g uo0
GUO1 g uo1
GUO2 g uo2
GUO3 g uo3
GUO4 g uo4
HA0 h a0
HA1 h a1
HA2 h a2
HA3 h a3
HA4 h a4
HAI0 h ai0
HAI1 h ai1
HAI2 h ai2
HAI3 h ai3
HAI4 h ai4
HAN0 h an0
HAN1 h an1
HAN2 h an2
HAN3 h an3
HAN4 h an4
HANG0 h ang0
HANG1 h ang1
HANG2 h ang2
HANG3 h ang3
HANG4 h ang4
HAO0 h ao0
HAO1 h ao1
HAO2 h ao2
HAO3 h ao3
HAO4 h ao4
HE0 h e0
HE1 h e1
HE2 h e2
HE3 h e3
HE4 h e4
HEI0 h ei0
HEI1 h ei1
HEI2 h ei2
HEI3 h ei3
HEI4 h ei4
HEN0 h en0
HEN1 h en1
HEN2 h en2
HEN3 h en3
HEN4 h en4
HENG0 h eng0
HENG1 h eng1
HENG2 h eng2
HENG3 h eng3
HENG4 h eng4
HONG0 h ong0
HONG1 h ong1
HONG2 h ong2
HONG3 h ong3
HONG4 h ong4
HOU0 h ou0
HOU1 h ou1
HOU2 h ou2
HOU3 h ou3
HOU4 h ou4
HU0 h u0
HU1 h u1
HU2 h u2
HU3 h u3
HU4 h u4
HUA0 h ua0
HUA1 h ua1
HUA2 h ua2
HUA3 h ua3
HUA4 h ua4
HUAI0 h uai0
HUAI1 h uai1
HUAI2 h uai2
HUAI3 h uai3
HUAI4 h uai4
HUAN0 h uan0
HUAN1 h uan1
HUAN2 h uan2
HUAN3 h uan3
HUAN4 h uan4
HUANG0 h uang0
HUANG1 h uang1
HUANG2 h uang2
HUANG3 h uang3
HUANG4 h uang4
HUI0 h ui0
HUI1 h ui1
HUI2 h ui2
HUI3 h ui3
HUI4 h ui4
HUN0 h un0
HUN1 h un1
HUN2 h un2
HUN3 h un3
HUN4 h un4
HUO0 h uo0
HUO1 h uo1
HUO2 h uo2
HUO3 h uo3
HUO4 h uo4
JI0 j i0
JI1 j i1
JI2 j i2
JI3 j i3
JI4 j i4
JIA0 j ia0
JIA1 j ia1
JIA2 j ia2
JIA3 j ia3
JIA4 j ia4
JIAN0 j ian0
JIAN1 j ian1
JIAN2 j ian2
JIAN3 j ian3
JIAN4 j ian4
JIANG0 j iang0
JIANG1 j iang1
JIANG2 j iang2
JIANG3 j iang3
JIANG4 j iang4
JIAO0 j iao0
JIAO1 j iao1
JIAO2 j iao2
JIAO3 j iao3
JIAO4 j iao4
JIE0 j ie0
JIE1 j ie1
JIE2 j ie2
JIE3 j ie3
JIE4 j ie4
JIN0 j in0
JIN1 j in1
JIN2 j in2
JIN3 j in3
JIN4 j in4
JING0 j ing0
JING1 j ing1
JING2 j ing2
JING3 j ing3
JING4 j ing4
JIONG0 j iong0
JIONG1 j iong1
JIONG2 j iong2
JIONG3 j iong3
JIONG4 j iong4
JIU0 j iu0
JIU1 j iu1
JIU2 j iu2
JIU3 j iu3
JIU4 j iu4
JU0 j v0
JU1 j v1
JU2 j v2
JU3 j v3
JU4 j v4
JUAN0 j van0
JUAN1 j van1
JUAN2 j van2
JUAN3 j van3
JUAN4 j van4
JUE0 j ve0
JUE1 j ve1
JUE2 j ve2
JUE3 j ve3
JUE4 j ve4
JUN0 j vn0
JUN1 j vn1
JUN2 j vn2
JUN3 j vn3
JUN4 j vn4
KA0 k a0
KA1 k a1
KA2 k a2
KA3 k a3
KA4 k a4
KAI0 k ai0
KAI1 k ai1
KAI2 k ai2
KAI3 k ai3
KAI4 k ai4
KAN0 k an0
KAN1 k an1
KAN2 k an2
KAN3 k an3
KAN4 k an4
KANG0 k ang0
KANG1 k ang1
KANG2 k ang2
KANG3 k ang3
KANG4 k ang4
KAO0 k ao0
KAO1 k ao1
KAO2 k ao2
KAO3 k ao3
KAO4 k ao4
KE0 k e0
KE1 k e1
KE2 k e2
KE3 k e3
KE4 k e4
KEI0 k ei0
KEI1 k ei1
KEI2 k ei2
KEI3 k ei3
KEI4 k ei4
KEN0 k en0
KEN1 k en1
KEN2 k en2
KEN3 k en3
KEN4 k en4
KENG0 k eng0
KENG1 k eng1
KENG2 k eng2
KENG3 k eng3
KENG4 k eng4
KONG0 k ong0
KONG1 k ong1
KONG2 k ong2
KONG3 k ong3
KONG4 k ong4
KOU0 k ou0
KOU1 k ou1
KOU2 k ou2
KOU3 k ou3
KOU4 k ou4
KU0 k u0
KU1 k u1
KU2 k u2
KU3 k u3
KU4 k u4
KUA0 k ua0
KUA1 k ua1
KUA2 k ua2
KUA3 k ua3
KUA4 k ua4
KUAI0 k uai0
KUAI1 k uai1
KUAI2 k uai2
KUAI3 k uai3
KUAI4 k uai4
KUAN0 k uan0
KUAN1 k uan1
KUAN2 k uan2
KUAN3 k uan3
KUAN4 k uan4
KUANG0 k uang0
KUANG1 k uang1
KUANG2 k uang2
KUANG3 k uang3
KUANG4 k uang4
KUI0 k ui0
KUI1 k ui1
KUI2 k ui2
KUI3 k ui3
KUI4 k ui4
KUN0 k un0
KUN1 k un1
KUN2 k un2
KUN3 k un3
KUN4 k un4
KUO0 k uo0
KUO1 k uo1
KUO2 k uo2
KUO3 k uo3
KUO4 k uo4
LA0 l a0
LA1 l a1
LA2 l a2
LA3 l a3
LA4 l a4
LAI0 l ai0
LAI1 l ai1
LAI2 l ai2
LAI3 l ai3
LAI4 l ai4
LAN0 l an0
LAN1 l an1
LAN2 l an2
LAN3 l an3
LAN4 l an4
LANG0 l ang0
LANG1 l ang1
LANG2 l ang2
LANG3 l ang3
LANG4 l ang4
LAO0 l ao0
LAO1 l ao1
LAO2 l ao2
LAO3 l ao3
LAO4 l ao4
LE0 l e0
LE1 l e1
LE2 l e2
LE3 l e3
LE4 l e4
LEI0 l ei0
LEI1 l ei1
LEI2 l ei2
LEI3 l ei3
LEI4 l ei4
LENG0 l eng0
LENG1 l eng1
LENG2 l eng2
LENG3 l eng3
LENG4 l eng4
LI0 l i0
LI1 l i1
LI2 l i2
LI3 l i3
LI4 l i4
LIA0 l ia0
LIA1 l ia1
LIA2 l ia2
LIA3 l ia3
LIA4 l ia4
LIAN0 l ian0
LIAN1 l ian1
LIAN2 l ian2
LIAN3 l ian3
LIAN4 l ian4
LIANG0 l iang0
LIANG1 l iang1
LIANG2 l iang2
LIANG3 l iang3
LIANG4 l iang4
LIAO0 l iao0
LIAO1 l iao1
LIAO2 l iao2
LIAO3 l iao3
LIAO4 l iao4
LIE0 l ie0
LIE1 l ie1
LIE2 l ie2
LIE3 l ie3
LIE4 l ie4
LIN0 l in0
LIN1 l in1
LIN2 l in2
LIN3 l in3
LIN4 l in4
LING0 l ing0
LING1 l ing1
LING2 l ing2
LING3 l ing3
LING4 l ing4
LIU0 l iu0
LIU1 l iu1
LIU2 l iu2
LIU3 l iu3
LIU4 l iu4
LONG0 l ong0
LONG1 l ong1
LONG2 l ong2
LONG3 l ong3
LONG4 l ong4
LOU0 l ou0
LOU1 l ou1
LOU2 l ou2
LOU3 l ou3
LOU4 l ou4
LU0 l u0
LU1 l u1
LU2 l u2
LU3 l u3
LU4 l u4
LUAN0 l uan0
LUAN1 l uan1
LUAN2 l uan2
LUAN3 l uan3
LUAN4 l uan4
LUE0 l ve0
LUE1 l ve1
LUE2 l ve2
LUE3 l ve3
LUE4 l ve4
LVE0 l ve0
LVE1 l ve1
LVE2 l ve2
LVE3 l ve3
LVE4 l ve4
LUN0 l un0
LUN1 l un1
LUN2 l un2
LUN3 l un3
LUN4 l un4
LUO0 l uo0
LUO1 l uo1
LUO2 l uo2
LUO3 l uo3
LUO4 l uo4
LV0 l v0
LV1 l v1
LV2 l v2
LV3 l v3
LV4 l v4
MA0 m a0
MA1 m a1
MA2 m a2
MA3 m a3
MA4 m a4
MAI0 m ai0
MAI1 m ai1
MAI2 m ai2
MAI3 m ai3
MAI4 m ai4
MAN0 m an0
MAN1 m an1
MAN2 m an2
MAN3 m an3
MAN4 m an4
MANG0 m ang0
MANG1 m ang1
MANG2 m ang2
MANG3 m ang3
MANG4 m ang4
MAO0 m ao0
MAO1 m ao1
MAO2 m ao2
MAO3 m ao3
MAO4 m ao4
ME0 m e0
ME1 m e1
ME2 m e2
ME3 m e3
ME4 m e4
MEI0 m ei0
MEI1 m ei1
MEI2 m ei2
MEI3 m ei3
MEI4 m ei4
MEN0 m en0
MEN1 m en1
MEN2 m en2
MEN3 m en3
MEN4 m en4
MENG0 m eng0
MENG1 m eng1
MENG2 m eng2
MENG3 m eng3
MENG4 m eng4
MI0 m i0
MI1 m i1
MI2 m i2
MI3 m i3
MI4 m i4
MIAN0 m ian0
MIAN1 m ian1
MIAN2 m ian2
MIAN3 m ian3
MIAN4 m ian4
MIAO0 m iao0
MIAO1 m iao1
MIAO2 m iao2
MIAO3 m iao3
MIAO4 m iao4
MIE0 m ie0
MIE1 m ie1
MIE2 m ie2
MIE3 m ie3
MIE4 m ie4
MIN0 m in0
MIN1 m in1
MIN2 m in2
MIN3 m in3
MIN4 m in4
MING0 m ing0
MING1 m ing1
MING2 m ing2
MING3 m ing3
MING4 m ing4
MIU0 m iu0
MIU1 m iu1
MIU2 m iu2
MIU3 m iu3
MIU4 m iu4
MO0 m o0
MO1 m o1
MO2 m o2
MO3 m o3
MO4 m o4
MOU0 m ou0
MOU1 m ou1
MOU2 m ou2
MOU3 m ou3
MOU4 m ou4
MU0 m u0
MU1 m u1
MU2 m u2
MU3 m u3
MU4 m u4
NA0 n a0
NA1 n a1
NA2 n a2
NA3 n a3
NA4 n a4
NAI0 n ai0
NAI1 n ai1
NAI2 n ai2
NAI3 n ai3
NAI4 n ai4
NAN0 n an0
NAN1 n an1
NAN2 n an2
NAN3 n an3
NAN4 n an4
NANG0 n ang0
NANG1 n ang1
NANG2 n ang2
NANG3 n ang3
NANG4 n ang4
NAO0 n ao0
NAO1 n ao1
NAO2 n ao2
NAO3 n ao3
NAO4 n ao4
NE0 n e0
NE1 n e1
NE2 n e2
NE3 n e3
NE4 n e4
NEI0 n ei0
NEI1 n ei1
NEI2 n ei2
NEI3 n ei3
NEI4 n ei4
NEN0 n en0
NEN1 n en1
NEN2 n en2
NEN3 n en3
NEN4 n en4
NENG0 n eng0
NENG1 n eng1
NENG2 n eng2
NENG3 n eng3
NENG4 n eng4
NI0 n i0
NI1 n i1
NI2 n i2
NI3 n i3
NI4 n i4
NIAN0 n ian0
NIAN1 n ian1
NIAN2 n ian2
NIAN3 n ian3
NIAN4 n ian4
NIANG0 n iang0
NIANG1 n iang1
NIANG2 n iang2
NIANG3 n iang3
NIANG4 n iang4
NIAO0 n iao0
NIAO1 n iao1
NIAO2 n iao2
NIAO3 n iao3
NIAO4 n iao4
NIE0 n ie0
NIE1 n ie1
NIE2 n ie2
NIE3 n ie3
NIE4 n ie4
NIN0 n in0
NIN1 n in1
NIN2 n in2
NIN3 n in3
NIN4 n in4
NING0 n ing0
NING1 n ing1
NING2 n ing2
NING3 n ing3
NING4 n ing4
NIU0 n iu0
NIU1 n iu1
NIU2 n iu2
NIU3 n iu3
NIU4 n iu4
NONG0 n ong0
NONG1 n ong1
NONG2 n ong2
NONG3 n ong3
NONG4 n ong4
NU0 n u0
NU1 n u1
NU2 n u2
NU3 n u3
NU4 n u4
NUAN0 n uan0
NUAN1 n uan1
NUAN2 n uan2
NUAN3 n uan3
NUAN4 n uan4
NUE0 n ve0
NUE1 n ve1
NUE2 n ve2
NUE3 n ve3
NUE4 n ve4
NVE0 n ve0
NVE1 n ve1
NVE2 n ve2
NVE3 n ve3
NVE4 n ve4
NUO0 n uo0
NUO1 n uo1
NUO2 n uo2
NUO3 n uo3
NUO4 n uo4
NV0 n v0
NV1 n v1
NV2 n v2
NV3 n v3
NV4 n v4
O0 oo o0
O1 oo o1
O2 oo o2
O3 oo o3
O4 oo o4
OU0 oo ou0
OU1 oo ou1
OU2 oo ou2
OU3 oo ou3
OU4 oo ou4
PA0 p a0
PA1 p a1
PA2 p a2
PA3 p a3
PA4 p a4
PAI0 p ai0
PAI1 p ai1
PAI2 p ai2
PAI3 p ai3
PAI4 p ai4
PAN0 p an0
PAN1 p an1
PAN2 p an2
PAN3 p an3
PAN4 p an4
PANG0 p ang0
PANG1 p ang1
PANG2 p ang2
PANG3 p ang3
PANG4 p ang4
PAO0 p ao0
PAO1 p ao1
PAO2 p ao2
PAO3 p ao3
PAO4 p ao4
PEI0 p ei0
PEI1 p ei1
PEI2 p ei2
PEI3 p ei3
PEI4 p ei4
PEN0 p en0
PEN1 p en1
PEN2 p en2
PEN3 p en3
PEN4 p en4
PENG0 p eng0
PENG1 p eng1
PENG2 p eng2
PENG3 p eng3
PENG4 p eng4
PI0 p i0
PI1 p i1
PI2 p i2
PI3 p i3
PI4 p i4
PIAN0 p ian0
PIAN1 p ian1
PIAN2 p ian2
PIAN3 p ian3
PIAN4 p ian4
PIAO0 p iao0
PIAO1 p iao1
PIAO2 p iao2
PIAO3 p iao3
PIAO4 p iao4
PIE0 p ie0
PIE1 p ie1
PIE2 p ie2
PIE3 p ie3
PIE4 p ie4
PIN0 p in0
PIN1 p in1
PIN2 p in2
PIN3 p in3
PIN4 p in4
PING0 p ing0
PING1 p ing1
PING2 p ing2
PING3 p ing3
PING4 p ing4
PO0 p o0
PO1 p o1
PO2 p o2
PO3 p o3
PO4 p o4
POU0 p ou0
POU1 p ou1
POU2 p ou2
POU3 p ou3
POU4 p ou4
PU0 p u0
PU1 p u1
PU2 p u2
PU3 p u3
PU4 p u4
QI0 q i0
QI1 q i1
QI2 q i2
QI3 q i3
QI4 q i4
QIA0 q ia0
QIA1 q ia1
QIA2 q ia2
QIA3 q ia3
QIA4 q ia4
QIAN0 q ian0
QIAN1 q ian1
QIAN2 q ian2
QIAN3 q ian3
QIAN4 q ian4
QIANG0 q iang0
QIANG1 q iang1
QIANG2 q iang2
QIANG3 q iang3
QIANG4 q iang4
QIAO0 q iao0
QIAO1 q iao1
QIAO2 q iao2
QIAO3 q iao3
QIAO4 q iao4
QIE0 q ie0
QIE1 q ie1
QIE2 q ie2
QIE3 q ie3
QIE4 q ie4
QIN0 q in0
QIN1 q in1
QIN2 q in2
QIN3 q in3
QIN4 q in4
QING0 q ing0
QING1 q ing1
QING2 q ing2
QING3 q ing3
QING4 q ing4
QIONG0 q iong0
QIONG1 q iong1
QIONG2 q iong2
QIONG3 q iong3
QIONG4 q iong4
QIU0 q iu0
QIU1 q iu1
QIU2 q iu2
QIU3 q iu3
QIU4 q iu4
QU0 q v0
QU1 q v1
QU2 q v2
QU3 q v3
QU4 q v4
QUAN0 q van0
QUAN1 q van1
QUAN2 q van2
QUAN3 q van3
QUAN4 q van4
QUE0 q ve0
QUE1 q ve1
QUE2 q ve2
QUE3 q ve3
QUE4 q ve4
QUN0 q vn0
QUN1 q vn1
QUN2 q vn2
QUN3 q vn3
QUN4 q vn4
RAN0 r an0
RAN1 r an1
RAN2 r an2
RAN3 r an3
RAN4 r an4
RANG0 r ang0
RANG1 r ang1
RANG2 r ang2
RANG3 r ang3
RANG4 r ang4
RAO0 r ao0
RAO1 r ao1
RAO2 r ao2
RAO3 r ao3
RAO4 r ao4
RE0 r e0
RE1 r e1
RE2 r e2
RE3 r e3
RE4 r e4
REN0 r en0
REN1 r en1
REN2 r en2
REN3 r en3
REN4 r en4
RENG0 r eng0
RENG1 r eng1
RENG2 r eng2
RENG3 r eng3
RENG4 r eng4
RI0 r iz0
RI1 r iz1
RI2 r iz2
RI3 r iz3
RI4 r iz4
RONG0 r ong0
RONG1 r ong1
RONG2 r ong2
RONG3 r ong3
RONG4 r ong4
ROU0 r ou0
ROU1 r ou1
ROU2 r ou2
ROU3 r ou3
ROU4 r ou4
RU0 r u0
RU1 r u1
RU2 r u2
RU3 r u3
RU4 r u4
RUAN0 r uan0
RUAN1 r uan1
RUAN2 r uan2
RUAN3 r uan3
RUAN4 r uan4
RUI0 r ui0
RUI1 r ui1
RUI2 r ui2
RUI3 r ui3
RUI4 r ui4
RUN0 r un0
RUN1 r un1
RUN2 r un2
RUN3 r un3
RUN4 r un4
RUO0 r uo0
RUO1 r uo1
RUO2 r uo2
RUO3 r uo3
RUO4 r uo4
SA0 s a0
SA1 s a1
SA2 s a2
SA3 s a3
SA4 s a4
SAI0 s ai0
SAI1 s ai1
SAI2 s ai2
SAI3 s ai3
SAI4 s ai4
SAN0 s an0
SAN1 s an1
SAN2 s an2
SAN3 s an3
SAN4 s an4
SANG0 s ang0
SANG1 s ang1
SANG2 s ang2
SANG3 s ang3
SANG4 s ang4
SAO0 s ao0
SAO1 s ao1
SAO2 s ao2
SAO3 s ao3
SAO4 s ao4
SE0 s e0
SE1 s e1
SE2 s e2
SE3 s e3
SE4 s e4
SEN0 s en0
SEN1 s en1
SEN2 s en2
SEN3 s en3
SEN4 s en4
SENG0 s eng0
SENG1 s eng1
SENG2 s eng2
SENG3 s eng3
SENG4 s eng4
SHA0 sh a0
SHA1 sh a1
SHA2 sh a2
SHA3 sh a3
SHA4 sh a4
SHAI0 sh ai0
SHAI1 sh ai1
SHAI2 sh ai2
SHAI3 sh ai3
SHAI4 sh ai4
SHAN0 sh an0
SHAN1 sh an1
SHAN2 sh an2
SHAN3 sh an3
SHAN4 sh an4
SHANG0 sh ang0
SHANG1 sh ang1
SHANG2 sh ang2
SHANG3 sh ang3
SHANG4 sh ang4
SHAO0 sh ao0
SHAO1 sh ao1
SHAO2 sh ao2
SHAO3 sh ao3
SHAO4 sh ao4
SHE0 sh e0
SHE1 sh e1
SHE2 sh e2
SHE3 sh e3
SHE4 sh e4
SHEI0 sh ei0
SHEI1 sh ei1
SHEI2 sh ei2
SHEI3 sh ei3
SHEI4 sh ei4
SHEN0 sh en0
SHEN1 sh en1
SHEN2 sh en2
SHEN3 sh en3
SHEN4 sh en4
SHENG0 sh eng0
SHENG1 sh eng1
SHENG2 sh eng2
SHENG3 sh eng3
SHENG4 sh eng4
SHI0 sh ix0
SHI1 sh ix1
SHI2 sh ix2
SHI3 sh ix3
SHI4 sh ix4
SHOU0 sh ou0
SHOU1 sh ou1
SHOU2 sh ou2
SHOU3 sh ou3
SHOU4 sh ou4
SHU0 sh u0
SHU1 sh u1
SHU2 sh u2
SHU3 sh u3
SHU4 sh u4
SHUA0 sh ua0
SHUA1 sh ua1
SHUA2 sh ua2
SHUA3 sh ua3
SHUA4 sh ua4
SHUAI0 sh uai0
SHUAI1 sh uai1
SHUAI2 sh uai2
SHUAI3 sh uai3
SHUAI4 sh uai4
SHUAN0 sh uan0
SHUAN1 sh uan1
SHUAN2 sh uan2
SHUAN3 sh uan3
SHUAN4 sh uan4
SHUANG0 sh uang0
SHUANG1 sh uang1
SHUANG2 sh uang2
SHUANG3 sh uang3
SHUANG4 sh uang4
SHUI0 sh ui0
SHUI1 sh ui1
SHUI2 sh ui2
SHUI3 sh ui3
SHUI4 sh ui4
SHUN0 sh un0
SHUN1 sh un1
SHUN2 sh un2
SHUN3 sh un3
SHUN4 sh un4
SHUO0 sh uo0
SHUO1 sh uo1
SHUO2 sh uo2
SHUO3 sh uo3
SHUO4 sh uo4
SI0 s iy0
SI1 s iy1
SI2 s iy2
SI3 s iy3
SI4 s iy4
SONG0 s ong0
SONG1 s ong1
SONG2 s ong2
SONG3 s ong3
SONG4 s ong4
SOU0 s ou0
SOU1 s ou1
SOU2 s ou2
SOU3 s ou3
SOU4 s ou4
SU0 s u0
SU1 s u1
SU2 s u2
SU3 s u3
SU4 s u4
SUAN0 s uan0
SUAN1 s uan1
SUAN2 s uan2
SUAN3 s uan3
SUAN4 s uan4
SUI0 s ui0
SUI1 s ui1
SUI2 s ui2
SUI3 s ui3
SUI4 s ui4
SUN0 s un0
SUN1 s un1
SUN2 s un2
SUN3 s un3
SUN4 s un4
SUO0 s uo0
SUO1 s uo1
SUO2 s uo2
SUO3 s uo3
SUO4 s uo4
TA0 t a0
TA1 t a1
TA2 t a2
TA3 t a3
TA4 t a4
TAI0 t ai0
TAI1 t ai1
TAI2 t ai2
TAI3 t ai3
TAI4 t ai4
TAN0 t an0
TAN1 t an1
TAN2 t an2
TAN3 t an3
TAN4 t an4
TANG0 t ang0
TANG1 t ang1
TANG2 t ang2
TANG3 t ang3
TANG4 t ang4
TAO0 t ao0
TAO1 t ao1
TAO2 t ao2
TAO3 t ao3
TAO4 t ao4
TE0 t e0
TE1 t e1
TE2 t e2
TE3 t e3
TE4 t e4
TENG0 t eng0
TENG1 t eng1
TENG2 t eng2
TENG3 t eng3
TENG4 t eng4
TI0 t i0
TI1 t i1
TI2 t i2
TI3 t i3
TI4 t i4
TIAN0 t ian0
TIAN1 t ian1
TIAN2 t ian2
TIAN3 t ian3
TIAN4 t ian4
TIAO0 t iao0
TIAO1 t iao1
TIAO2 t iao2
TIAO3 t iao3
TIAO4 t iao4
TIE0 t ie0
TIE1 t ie1
TIE2 t ie2
TIE3 t ie3
TIE4 t ie4
TING0 t ing0
TING1 t ing1
TING2 t ing2
TING3 t ing3
TING4 t ing4
TONG0 t ong0
TONG1 t ong1
TONG2 t ong2
TONG3 t ong3
TONG4 t ong4
TOU0 t ou0
TOU1 t ou1
TOU2 t ou2
TOU3 t ou3
TOU4 t ou4
TU0 t u0
TU1 t u1
TU2 t u2
TU3 t u3
TU4 t u4
TUAN0 t uan0
TUAN1 t uan1
TUAN2 t uan2
TUAN3 t uan3
TUAN4 t uan4
TUI0 t ui0
TUI1 t ui1
TUI2 t ui2
TUI3 t ui3
TUI4 t ui4
TUN0 t un0
TUN1 t un1
TUN2 t un2
TUN3 t un3
TUN4 t un4
TUO0 t uo0
TUO1 t uo1
TUO2 t uo2
TUO3 t uo3
TUO4 t uo4
WA0 uu ua0
WA1 uu ua1
WA2 uu ua2
WA3 uu ua3
WA4 uu ua4
WAI0 uu uai0
WAI1 uu uai1
WAI2 uu uai2
WAI3 uu uai3
WAI4 uu uai4
WAN0 uu uan0
WAN1 uu uan1
WAN2 uu uan2
WAN3 uu uan3
WAN4 uu uan4
WANG0 uu uang0
WANG1 uu uang1
WANG2 uu uang2
WANG3 uu uang3
WANG4 uu uang4
WEI0 uu ui0
WEI1 uu ui1
WEI2 uu ui2
WEI3 uu ui3
WEI4 uu ui4
WEN0 uu un0
WEN1 uu un1
WEN2 uu un2
WEN3 uu un3
WEN4 uu un4
WENG0 uu ueng0
WENG1 uu ueng1
WENG2 uu ueng2
WENG3 uu ueng3
WENG4 uu ueng4
WO0 uu uo0
WO1 uu uo1
WO2 uu uo2
WO3 uu uo3
WO4 uu uo4
WU0 uu u0
WU1 uu u1
WU2 uu u2
WU3 uu u3
WU4 uu u4
XI0 x i0
XI1 x i1
XI2 x i2
XI3 x i3
XI4 x i4
XIA0 x ia0
XIA1 x ia1
XIA2 x ia2
XIA3 x ia3
XIA4 x ia4
XIAN0 x ian0
XIAN1 x ian1
XIAN2 x ian2
XIAN3 x ian3
XIAN4 x ian4
XIANG0 x iang0
XIANG1 x iang1
XIANG2 x iang2
XIANG3 x iang3
XIANG4 x iang4
XIAO0 x iao0
XIAO1 x iao1
XIAO2 x iao2
XIAO3 x iao3
XIAO4 x iao4
XIE0 x ie0
XIE1 x ie1
XIE2 x ie2
XIE3 x ie3
XIE4 x ie4
XIN0 x in0
XIN1 x in1
XIN2 x in2
XIN3 x in3
XIN4 x in4
XING0 x ing0
XING1 x ing1
XING2 x ing2
XING3 x ing3
XING4 x ing4
XIONG0 x iong0
XIONG1 x iong1
XIONG2 x iong2
XIONG3 x iong3
XIONG4 x iong4
XIU0 x iu0
XIU1 x iu1
XIU2 x iu2
XIU3 x iu3
XIU4 x iu4
XU0 x v0
XU1 x v1
XU2 x v2
XU3 x v3
XU4 x v4
XUAN0 x van0
XUAN1 x van1
XUAN2 x van2
XUAN3 x van3
XUAN4 x van4
XUE0 x ve0
XUE1 x ve1
XUE2 x ve2
XUE3 x ve3
XUE4 x ve4
XUN0 x vn0
XUN1 x vn1
XUN2 x vn2
XUN3 x vn3
XUN4 x vn4
YA0 ii ia0
YA1 ii ia1
YA2 ii ia2
YA3 ii ia3
YA4 ii ia4
YAN0 ii ian0
YAN1 ii ian1
YAN2 ii ian2
YAN3 ii ian3
YAN4 ii ian4
YANG0 ii iang0
YANG1 ii iang1
YANG2 ii iang2
YANG3 ii iang3
YANG4 ii iang4
YAO0 ii iao0
YAO1 ii iao1
YAO2 ii iao2
YAO3 ii iao3
YAO4 ii iao4
YE0 ii ie0
YE1 ii ie1
YE2 ii ie2
YE3 ii ie3
YE4 ii ie4
YI0 ii i0
YI1 ii i1
YI2 ii i2
YI3 ii i3
YI4 ii i4
YIN0 ii in0
YIN1 ii in1
YIN2 ii in2
YIN3 ii in3
YIN4 ii in4
YING0 ii ing0
YING1 ii ing1
YING2 ii ing2
YING3 ii ing3
YING4 ii ing4
YO0 ii ou0
YO1 ii ou1
YO2 ii ou2
YO3 ii ou3
YO4 ii ou4
YONG0 ii iong0
YONG1 ii iong1
YONG2 ii iong2
YONG3 ii iong3
YONG4 ii iong4
YOU0 ii iu0
YOU1 ii iu1
YOU2 ii iu2
YOU3 ii iu3
YOU4 ii iu4
YU0 vv v0
YU1 vv v1
YU2 vv v2
YU3 vv v3
YU4 vv v4
YUAN0 vv van0
YUAN1 vv van1
YUAN2 vv van2
YUAN3 vv van3
YUAN4 vv van4
YUE0 vv ve0
YUE1 vv ve1
YUE2 vv ve2
YUE3 vv ve3
YUE4 vv ve4
YUN0 vv vn0
YUN1 vv vn1
YUN2 vv vn2
YUN3 vv vn3
YUN4 vv vn4
YUO0 ii ou0
YUO1 ii ou1
YUO2 ii ou2
YUO3 ii ou3
YUO4 ii ou4
ZA0 z a0
ZA1 z a1
ZA2 z a2
ZA3 z a3
ZA4 z a4
ZAI0 z ai0
ZAI1 z ai1
ZAI2 z ai2
ZAI3 z ai3
ZAI4 z ai4
ZAN0 z an0
ZAN1 z an1
ZAN2 z an2
ZAN3 z an3
ZAN4 z an4
ZANG0 z ang0
ZANG1 z ang1
ZANG2 z ang2
ZANG3 z ang3
ZANG4 z ang4
ZAO0 z ao0
ZAO1 z ao1
ZAO2 z ao2
ZAO3 z ao3
ZAO4 z ao4
ZE0 z e0
ZE1 z e1
ZE2 z e2
ZE3 z e3
ZE4 z e4
ZEI0 z ei0
ZEI1 z ei1
ZEI2 z ei2
ZEI3 z ei3
ZEI4 z ei4
ZEN0 z en0
ZEN1 z en1
ZEN2 z en2
ZEN3 z en3
ZEN4 z en4
ZENG0 z eng0
ZENG1 z eng1
ZENG2 z eng2
ZENG3 z eng3
ZENG4 z eng4
ZHA0 zh a0
ZHA1 zh a1
ZHA2 zh a2
ZHA3 zh a3
ZHA4 zh a4
ZHAI0 zh ai0
ZHAI1 zh ai1
ZHAI2 zh ai2
ZHAI3 zh ai3
ZHAI4 zh ai4
ZHAN0 zh an0
ZHAN1 zh an1
ZHAN2 zh an2
ZHAN3 zh an3
ZHAN4 zh an4
ZHANG0 zh ang0
ZHANG1 zh ang1
ZHANG2 zh ang2
ZHANG3 zh ang3
ZHANG4 zh ang4
ZHAO0 zh ao0
ZHAO1 zh ao1
ZHAO2 zh ao2
ZHAO3 zh ao3
ZHAO4 zh ao4
ZHE0 zh e0
ZHE1 zh e1
ZHE2 zh e2
ZHE3 zh e3
ZHE4 zh e4
ZHEI0 zh ei0
ZHEI1 zh ei1
ZHEI2 zh ei2
ZHEI3 zh ei3
ZHEI4 zh ei4
ZHEN0 zh en0
ZHEN1 zh en1
ZHEN2 zh en2
ZHEN3 zh en3
ZHEN4 zh en4
ZHENG0 zh eng0
ZHENG1 zh eng1
ZHENG2 zh eng2
ZHENG3 zh eng3
ZHENG4 zh eng4
ZHI0 zh ix0
ZHI1 zh ix1
ZHI2 zh ix2
ZHI3 zh ix3
ZHI4 zh ix4
ZHONG0 zh ong0
ZHONG1 zh ong1
ZHONG2 zh ong2
ZHONG3 zh ong3
ZHONG4 zh ong4
ZHOU0 zh ou0
ZHOU1 zh ou1
ZHOU2 zh ou2
ZHOU3 zh ou3
ZHOU4 zh ou4
ZHU0 zh u0
ZHU1 zh u1
ZHU2 zh u2
ZHU3 zh u3
ZHU4 zh u4
ZHUA0 zh ua0
ZHUA1 zh ua1
ZHUA2 zh ua2
ZHUA3 zh ua3
ZHUA4 zh ua4
ZHUAI0 zh uai0
ZHUAI1 zh uai1
ZHUAI2 zh uai2
ZHUAI3 zh uai3
ZHUAI4 zh uai4
ZHUAN0 zh uan0
ZHUAN1 zh uan1
ZHUAN2 zh uan2
ZHUAN3 zh uan3
ZHUAN4 zh uan4
ZHUANG0 zh uang0
ZHUANG1 zh uang1
ZHUANG2 zh uang2
ZHUANG3 zh uang3
ZHUANG4 zh uang4
ZHUI0 zh ui0
ZHUI1 zh ui1
ZHUI2 zh ui2
ZHUI3 zh ui3
ZHUI4 zh ui4
ZHUN0 zh un0
ZHUN1 zh un1
ZHUN2 zh un2
ZHUN3 zh un3
ZHUN4 zh un4
ZHUO0 zh uo0
ZHUO1 zh uo1
ZHUO2 zh uo2
ZHUO3 zh uo3
ZHUO4 zh uo4
ZI0 z iy0
ZI1 z iy1
ZI2 z iy2
ZI3 z iy3
ZI4 z iy4
ZONG0 z ong0
ZONG1 z ong1
ZONG2 z ong2
ZONG3 z ong3
ZONG4 z ong4
ZOU0 z ou0
ZOU1 z ou1
ZOU2 z ou2
ZOU3 z ou3
ZOU4 z ou4
ZU0 z u0
ZU1 z u1
ZU2 z u2
ZU3 z u3
ZU4 z u4
ZUAN0 z uan0
ZUAN1 z uan1
ZUAN2 z uan2
ZUAN3 z uan3
ZUAN4 z uan4
ZUI0 z ui0
ZUI1 z ui1
ZUI2 z ui2
ZUI3 z ui3
ZUI4 z ui4
ZUN0 z un0
ZUN1 z un1
ZUN2 z un2
ZUN3 z un3
ZUN4 z un4
ZUO0 z uo0
ZUO1 z uo1
ZUO2 z uo2
ZUO3 z uo3
ZUO4 z uo4
EI0 ee ei0
EI1 ee ei1
EI2 ee ei2
EI3 ee ei3
EI4 ee ei4
TEI0 t ei0
TEI1 t ei1
TEI2 t ei2
TEI3 t ei3
TEI4 t ei4
HNG0 ee eng0
HNG1 ee eng1
HNG2 ee eng2
HNG3 ee eng3
HNG4 ee eng4
LO0 l o0
LO1 l o1
LO2 l o2
LO3 l o3
LO4 l o4
N0 ee en0
N1 ee en1
N2 ee en2
N3 ee en3
N4 ee en4
NG0 ee eng0
NG1 ee eng1
NG2 ee eng2
NG3 ee eng3
NG4 ee eng4
NOU0 n ao0
NOU1 n ao1
NOU2 n ao2
NOU3 n ao3
NOU4 n ao4
SEI0 s ei0
SEI1 s ei1
SEI2 s ei2
SEI3 s ei3
SEI4 s ei4
A5 aa a5
AI5 aa ai5
AN5 aa an5
ANG5 aa ang5
AO5 aa ao5
BA5 b a5
BAI5 b ai5
BAN5 b an5
BANG5 b ang5
BAO5 b ao5
BEI5 b ei5
BEN5 b en5
BENG5 b eng5
BI5 b i5
BIAN5 b ian5
BIAO5 b iao5
BIE5 b ie5
BIN5 b in5
BING5 b ing5
BO5 b o5
BU5 b u5
CA5 c a5
CAI5 c ai5
CAN5 c an5
CANG5 c ang5
CAO5 c ao5
CE5 c e5
CEN5 c en5
CENG5 c eng5
CHA5 ch a5
CHAI5 ch ai5
CHAN5 ch an5
CHANG5 ch ang5
CHAO5 ch ao5
CHE5 ch e5
CHEN5 ch en5
CHENG5 ch eng5
CHI5 ch ix5
CHONG5 ch ong5
CHOU5 ch ou5
CHU5 ch u5
CHUAI5 ch uai5
CHUAN5 ch uan5
CHUANG5 ch uang5
CHUI5 ch ui5
CHUN5 ch un5
CHUO5 ch uo5
CI5 c iy5
CONG5 c ong5
COU5 c ou5
CU5 c u5
CUAN5 c uan5
CUI5 c ui5
CUN5 c un5
CUO5 c uo5
DA5 d a5
DAI5 d ai5
DAN5 d an5
DANG5 d ang5
DAO5 d ao5
DE5 d e5
DEI5 d ei5
DEN5 d en5
DENG5 d eng5
DI5 d i5
DIA5 d ia5
DIAN5 d ian5
DIAO5 d iao5
DIE5 d ie5
DING5 d ing5
DIU5 d iu5
DONG5 d ong5
DOU5 d ou5
DU5 d u5
DUAN5 d uan5
DUI5 d ui5
DUN5 d un5
DUO5 d uo5
E5 ee e5
EN5 ee en5
ER5 ee er5
FA5 f a5
FAN5 f an5
FANG5 f ang5
FEI5 f ei5
FEN5 f en5
FENG5 f eng5
FO5 f o5
FOU5 f ou5
FU5 f u5
GA5 g a5
GAI5 g ai5
GAN5 g an5
GANG5 g ang5
GAO5 g ao5
GE5 g e5
GEI5 g ei5
GEN5 g en5
GENG5 g eng5
GONG5 g ong5
GOU5 g ou5
GU5 g u5
GUA5 g ua5
GUAI5 g uai5
GUAN5 g uan5
GUANG5 g uang5
GUI5 g ui5
GUN5 g un5
GUO5 g uo5
HA5 h a5
HAI5 h ai5
HAN5 h an5
HANG5 h ang5
HAO5 h ao5
HE5 h e5
HEI5 h ei5
HEN5 h en5
HENG5 h eng5
HONG5 h ong5
HOU5 h ou5
HU5 h u5
HUA5 h ua5
HUAI5 h uai5
HUAN5 h uan5
HUANG5 h uang5
HUI5 h ui5
HUN5 h un5
HUO5 h uo5
JI5 j i5
JIA5 j ia5
JIAN5 j ian5
JIANG5 j iang5
JIAO5 j iao5
JIE5 j ie5
JIN5 j in5
JING5 j ing5
JIONG5 j iong5
JIU5 j iu5
JU5 j v5
JUAN5 j van5
JUE5 j ve5
JUN5 j vn5
KA5 k a5
KAI5 k ai5
KAN5 k an5
KANG5 k ang5
KAO5 k ao5
KE5 k e5
KEI5 k ei5
KEN5 k en5
KENG5 k eng5
KONG5 k ong5
KOU5 k ou5
KU5 k u5
KUA5 k ua5
KUAI5 k uai5
KUAN5 k uan5
KUANG5 k uang5
KUI5 k ui5
KUN5 k un5
KUO5 k uo5
LA5 l a5
LAI5 l ai5
LAN5 l an5
LANG5 l ang5
LAO5 l ao5
LE5 l e5
LEI5 l ei5
LENG5 l eng5
LI5 l i5
LIA5 l ia5
LIAN5 l ian5
LIANG5 l iang5
LIAO5 l iao5
LIE5 l ie5
LIN5 l in5
LING5 l ing5
LIU5 l iu5
LONG5 l ong5
LOU5 l ou5
LU5 l u5
LUAN5 l uan5
LUE5 l ve5
LVE5 l ve5
LUN5 l un5
LUO5 l uo5
LV5 l v5
MA5 m a5
MAI5 m ai5
MAN5 m an5
MANG5 m ang5
MAO5 m ao5
ME5 m e5
MEI5 m ei5
MEN5 m en5
MENG5 m eng5
MI5 m i5
MIAN5 m ian5
MIAO5 m iao5
MIE5 m ie5
MIN5 m in5
MING5 m ing5
MIU5 m iu5
MO5 m o5
MOU5 m ou5
MU5 m u5
NA5 n a5
NAI5 n ai5
NAN5 n an5
NANG5 n ang5
NAO5 n ao5
NE5 n e5
NEI5 n ei5
NEN5 n en5
NENG5 n eng5
NI5 n i5
NIAN5 n ian5
NIANG5 n iang5
NIAO5 n iao5
NIE5 n ie5
NIN5 n in5
NING5 n ing5
NIU5 n iu5
NONG5 n ong5
NU5 n u5
NUAN5 n uan5
NUE5 n ve5
NVE5 n ve5
NUO5 n uo5
NV5 n v5
O5 oo o5
OU5 oo ou5
PA5 p a5
PAI5 p ai5
PAN5 p an5
PANG5 p ang5
PAO5 p ao5
PEI5 p ei5
PEN5 p en5
PENG5 p eng5
PI5 p i5
PIAN5 p ian5
PIAO5 p iao5
PIE5 p ie5
PIN5 p in5
PING5 p ing5
PO5 p o5
POU5 p ou5
PU5 p u5
QI5 q i5
QIA5 q ia5
QIAN5 q ian5
QIANG5 q iang5
QIAO5 q iao5
QIE5 q ie5
QIN5 q in5
QING5 q ing5
QIONG5 q iong5
QIU5 q iu5
QU5 q v5
QUAN5 q van5
QUE5 q ve5
QUN5 q vn5
RAN5 r an5
RANG5 r ang5
RAO5 r ao5
RE5 r e5
REN5 r en5
RENG5 r eng5
RI5 r iz5
RONG5 r ong5
ROU5 r ou5
RU5 r u5
RUAN5 r uan5
RUI5 r ui5
RUN5 r un5
RUO5 r uo5
SA5 s a5
SAI5 s ai5
SAN5 s an5
SANG5 s ang5
SAO5 s ao5
SE5 s e5
SEN5 s en5
SENG5 s eng5
SHA5 sh a5
SHAI5 sh ai5
SHAN5 sh an5
SHANG5 sh ang5
SHAO5 sh ao5
SHE5 sh e5
SHEI5 sh ei5
SHEN5 sh en5
SHENG5 sh eng5
SHI5 sh ix5
SHOU5 sh ou5
SHU5 sh u5
SHUA5 sh ua5
SHUAI5 sh uai5
SHUAN5 sh uan5
SHUANG5 sh uang5
SHUI5 sh ui5
SHUN5 sh un5
SHUO5 sh uo5
SI5 s iy5
SONG5 s ong5
SOU5 s ou5
SU5 s u5
SUAN5 s uan5
SUI5 s ui5
SUN5 s un5
SUO5 s uo5
TA5 t a5
TAI5 t ai5
TAN5 t an5
TANG5 t ang5
TAO5 t ao5
TE5 t e5
TENG5 t eng5
TI5 t i5
TIAN5 t ian5
TIAO5 t iao5
TIE5 t ie5
TING5 t ing5
TONG5 t ong5
TOU5 t ou5
TU5 t u5
TUAN5 t uan5
TUI5 t ui5
TUN5 t un5
TUO5 t uo5
WA5 uu ua5
WAI5 uu uai5
WAN5 uu uan5
WANG5 uu uang5
WEI5 uu ui5
WEN5 uu un5
WENG5 uu ueng5
WO5 uu uo5
WU5 uu u5
XI5 x i5
XIA5 x ia5
XIAN5 x ian5
XIANG5 x iang5
XIAO5 x iao5
XIE5 x ie5
XIN5 x in5
XING5 x ing5
XIONG5 x iong5
XIU5 x iu5
XU5 x v5
XUAN5 x van5
XUE5 x ve5
XUN5 x vn5
YA5 ii ia5
YAN5 ii ian5
YANG5 ii iang5
YAO5 ii iao5
YE5 ii ie5
YI5 ii i5
YIN5 ii in5
YING5 ii ing5
YO5 ii ou5
YONG5 ii iong5
YOU5 ii iu5
YU5 vv v5
YUAN5 vv van5
YUE5 vv ve5
YUN5 vv vn5
YUO5 ii ou5
ZA5 z a5
ZAI5 z ai5
ZAN5 z an5
ZANG5 z ang5
ZAO5 z ao5
ZE5 z e5
ZEI5 z ei5
ZEN5 z en5
ZENG5 z eng5
ZHA5 zh a5
ZHAI5 zh ai5
ZHAN5 zh an5
ZHANG5 zh ang5
ZHAO5 zh ao5
ZHE5 zh e5
ZHEI5 zh ei5
ZHEN5 zh en5
ZHENG5 zh eng5
ZHI5 zh ix5
ZHONG5 zh ong5
ZHOU5 zh ou5
ZHU5 zh u5
ZHUA5 zh ua5
ZHUAI5 zh uai5
ZHUAN5 zh uan5
ZHUANG5 zh uang5
ZHUI5 zh ui5
ZHUN5 zh un5
ZHUO5 zh uo5
ZI5 z iy5
ZONG5 z ong5
ZOU5 z ou5
ZU5 z u5
ZUAN5 z uan5
ZUI5 z ui5
ZUN5 z un5
ZUO5 z uo5
EI5 ee ei5
TEI5 t ei5
HNG5 ee eng5
LO5 l o5
N5 ee en5
NG5 ee eng5
NOU5 n ao5
SEI5 s ei5
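The entries above form a pinyin-syllable-to-phone lexicon: each line is a syllable token followed by its initial and its final carrying a tone digit (0-4, plus 5 for the neutral tone further down). A minimal Python sketch of loading such a table for lookup — the file name is illustrative, not a confirmed repo path:

# Minimal sketch: load the syllable-to-phone entries above into a dict.
# "syllable.lexicon" is an illustrative file name.
lexicon = {}
with open("syllable.lexicon") as f:
    for line in f:
        parts = line.split()
        if parts:
            lexicon[parts[0]] = parts[1:]

print(lexicon["ZHONG1"])  # ['zh', 'ong1']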
#! /usr/bin/env bash
stage=-1
stop_stage=100
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
LEXICON_NAME=$1
# download data, generate manifests
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
python3 ${TARGET_DIR}/thchs30/thchs30.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/thchs30"
if [ $? -ne 0 ]; then
echo "Prepare THCHS-30 failed. Terminated."
exit 1
fi
fi
# dump manifest to data/
python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data
# copy files to data/dict to gen word.lexicon
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
# copy phone.lexicon to data/dict
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
# gen word.lexicon
python local/gen_word2phone.py --root-dir=data/dict --output-dir=data/dict
# reorganize dataset for MFA
if [ ! -d $EXP_DIR/thchs30_corpus ]; then
echo "reorganizing thchs30 corpus..."
python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME
echo "reorganization done."
fi
echo "THCHS-30 data preparation done."
exit 0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gen Chinese characters to THCHS30-30 phone lexicon using THCHS30-30's lexicon
file1: THCHS-30/data_thchs30/lm_word/lexicon.txt
file2: THCHS-30/resource/dict/lexicon.txt
"""
import argparse
from collections import defaultdict
from pathlib import Path
from typing import Union
# key: (cn, ('ee', 'er4')), value: count
cn_phones_counter = defaultdict(int)
# key: cn, value: list of (phones, num)
cn_counter = defaultdict(list)
# key: cn, value: list of (phones, probabilities)
cn_counter_p = defaultdict(list)
def is_Chinese(ch):
if '\u4e00' <= ch <= '\u9fff':
return True
return False
def proc_line(line):
line = line.strip()
if is_Chinese(line[0]):
line_list = line.split()
cn_list = line_list[0]
phone_list = line_list[1:]
if len(cn_list) == len(phone_list) / 2:
new_phone_list = [(phone_list[i], phone_list[i + 1])
for i in range(0, len(phone_list), 2)]
assert len(cn_list) == len(new_phone_list)
for idx, cn in enumerate(cn_list):
phones = new_phone_list[idx]
cn_phones_counter[(cn, phones)] += 1
def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
root_dir = Path(root_dir).expanduser()
output_dir = Path(output_dir).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
file1 = root_dir / "lm_word_lexicon_1"
file2 = root_dir / "lm_word_lexicon_2"
write_file = output_dir / "word.lexicon"
with open(file1, "r") as f1:
for line in f1:
proc_line(line)
with open(file2, "r") as f2:
for line in f2:
proc_line(line)
for key in cn_phones_counter:
cn = key[0]
cn_counter[cn].append((key[1], cn_phones_counter[key]))
for key in cn_counter:
phone_count_list = cn_counter[key]
count_sum = sum([x[1] for x in phone_count_list])
for item in phone_count_list:
p = item[1] / count_sum
p = round(p, 2)
if p > 0:
cn_counter_p[key].append((item[0], p))
with open(write_file, "w") as wf:
for key in cn_counter_p:
phone_p_list = cn_counter_p[key]
for item in phone_p_list:
phones, p = item
wf.write(key + " " + str(p) + " " + " ".join(phones) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Gen Chinese characters to phone lexicon for THCHS-30 dataset"
)
parser.add_argument(
"--root-dir", type=str, help="dir to thchs30 lm_word_lexicons")
parser.add_argument("--output-dir", type=str, help="path to save outputs")
args = parser.parse_args()
gen_lexicon(args.root_dir, args.output_dir)
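As a usage sketch: data.sh above invokes this script as `python local/gen_word2phone.py --root-dir=data/dict --output-dir=data/dict`. Each line written to word.lexicon has the form `<char> <probability> <initial> <final>`; a character whose counts split 9:1 between two pronunciations would yield two lines with probabilities 0.9 and 0.1 (illustrative numbers, not actual THCHS-30 counts).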
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recorganize THCHS-30 for MFA
read manifest.train from root-dir
Link *.wav to output-dir
dump *.lab from manifest.train, such as: text, syllable and phone
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
"""
import argparse
import os
from pathlib import Path
from typing import Union
def link_wav(root_dir: Union[str, Path], output_dir: Union[str, Path]):
wav_scp_path = root_dir / 'wav.scp'
with open(wav_scp_path, 'r') as rf:
for line in rf:
utt, feat = line.strip().split()
wav_path = feat
wav_name = wav_path.split("/")[-1]
new_wav_path = output_dir / wav_name
os.symlink(wav_path, new_wav_path)
def write_lab(root_dir: Union[str, Path],
output_dir: Union[str, Path],
script_type='phone'):
# script_type can be one of {'word', 'syllable', 'phone'}
json_name = 'text.' + script_type
json_path = root_dir / json_name
with open(json_path, 'r') as rf:
for line in rf:
line = line.strip().split()
utt_id = line[0]
context = ' '.join(line[1:])
transcript_name = utt_id + '.lab'
transcript_path = output_dir / transcript_name
with open(transcript_path, 'wt') as wf:
if script_type == 'word':
# add spaces between Chinese characters
context = ''.join([f + ' ' for f in context])[:-1]
wf.write(context + "\n")
def reorganize_thchs30(root_dir: Union[str, Path],
output_dir: Union[str, Path]=None,
script_type='phone'):
root_dir = Path(root_dir).expanduser()
output_dir = Path(output_dir).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
link_wav(root_dir, output_dir)
write_lab(root_dir, output_dir, script_type)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Reorganize THCHS-30 dataset for MFA")
parser.add_argument("--root-dir", type=str, help="path to thchs30 dataset.")
parser.add_argument(
"--output-dir",
type=str,
help="path to save outputs(audio and transcriptions)")
parser.add_argument(
"--script-type",
type=str,
default="phone",
help="type of lab ('word'/'syllable'/'phone')")
args = parser.parse_args()
reorganize_thchs30(args.root_dir, args.output_dir, args.script_type)
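As a usage sketch, data.sh above calls this script as `python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=phone`, which symlinks every wav listed in wav.scp into the corpus dir and writes one .lab transcript per utterance next to it — the layout MFA expects.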
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# MFA is in tools
export PATH=${MAIN_ROOT}/tools/montreal-forced-aligner/bin:$PATH
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
EXP_DIR=exp
# LEXICON_NAME in {'phone', 'syllable', 'word'}
LEXICON_NAME='phone'
# set MFA num_jobs to half of the machine's CPU core count
NUM_JOBS=$((`nproc`/2))
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# download dataset, unzip and generate manifest
# gen lexicon, relink wavs, dump labels
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh $LEXICON_NAME || exit -1
fi
# run MFA
if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then
echo "Start MFA training..."
mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS
echo "training done! \nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n"
fi
# TIMIT
* s1 u2 model with phone unit
# TIMIT
Results will be organized and updated soon.
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 0.0
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
},
"prob": 1.0
}
]
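A minimal sketch of how a data pipeline might consume the augmentation config above: each augmentor fires independently with probability `prob` (so `speed` above is effectively disabled). The dispatch below is a hypothetical stand-in for the real augmentor classes:

import json
import random

with open("conf/augmentation.json") as f:
    augmentors = json.load(f)

for aug in augmentors:
    if random.random() < aug["prob"]:
        # apply the augmentor named aug["type"] with aug["params"];
        # real code would dispatch to a shift/speed/specaug implementation
        print("applying", aug["type"], "with", aug["params"])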
faks0
fdac1
fjem0
mgwt0
mjar0
mmdb1
mmdm2
mpdf0
fcmh0
fkms0
mbdg0
mbwm0
mcsh0
fadg0
fdms0
fedw0
mgjf0
mglb0
mrtk0
mtaa0
mtdt0
mthc0
mwjg0
fnmr0
frew0
fsem0
mbns0
mmjr0
mdls0
mdlf0
mdvc0
mers0
fmah0
fdrw0
mrcs0
mrjm4
fcal1
mmwh0
fjsj0
majc0
mjsw0
mreb0
fgjd0
fjmg0
mroa0
mteb0
mjfc0
mrjr0
fmml0
mrws1
aa aa aa
ae ae ae
ah ah ah
ao ao aa
aw aw aw
ax ax ah
ax-h ax ah
axr er er
ay ay ay
b b b
bcl vcl sil
ch ch ch
d d d
dcl vcl sil
dh dh dh
dx dx dx
eh eh eh
el el l
em m m
en en n
eng ng ng
epi epi sil
er er er
ey ey ey
f f f
g g g
gcl vcl sil
h# sil sil
hh hh hh
hv hh hh
ih ih ih
ix ix ih
iy iy iy
jh jh jh
k k k
kcl cl sil
l l l
m m m
n n n
ng ng ng
nx n n
ow ow ow
oy oy oy
p p p
pau sil sil
pcl cl sil
q
r r r
s s s
sh sh sh
t t t
tcl cl sil
th th th
uh uh uh
uw uw uw
ux uw uw
v v v
w w w
y y y
z z z
zh zh sh
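The three columns above are the standard TIMIT 60-, 48-, and 39-phone inventories; the bare `q` line means the glottal stop has no mapping and is dropped. A minimal Python sketch of applying the 60-to-39 mapping (the same reduction timit_norm_trans.pl performs further below; the map path matches the one referenced in timit_data_prep.sh):

# Minimal sketch: apply the 60->39 phone mapping above to one transcript.
phonemap = {}
with open("conf/phones.60-48-39.map") as f:
    for line in f:
        parts = line.split()
        if len(parts) == 3:  # the bare 'q' line has no mapping
            phonemap[parts[0]] = parts[2]

trans = "h# hh ah dx ux w iy dcl d ix f ay n ih q h#"
print(" ".join(phonemap[p] for p in trans.split() if p in phonemap))
# -> sil hh ah dx uw w iy sil d ih f ay n ih sil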
mdab0
mwbt0
felc0
mtas1
mwew0
fpas0
mjmp0
mlnt0
fpkt0
mlll0
mtls0
fjlm0
mbpm0
mklt0
fnlp0
mcmj0
mjdh0
fmgd0
mgrt0
mnjm0
fdhc0
mjln0
mpam0
fmld0
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5 # second
max_input_len: 30.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 100.0
collator:
vocab_filepath: data/vocab.txt
unit_type: "word"
mean_std_filepath: ""
augmentation_config: ""
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: transformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
normalize_before: true
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 2
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.002
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 400
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
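A minimal sketch of reading the config above — it is plain YAML, so PyYAML suffices for inspection even though the project's own loader may differ; the file name matches `conf_path` in the run.sh further below:

import yaml

with open("conf/transformer.yaml") as f:
    conf = yaml.safe_load(f)

print(conf["model"]["encoder_conf"]["output_size"])  # 256
print(conf["decoding"]["decoding_method"])           # attention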
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# alignment results are dumped into `result_file`
# .tier and .TextGrid files are dumped into the dir of `result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#!/bin/bash
stage=-1
stop_stage=100
unit_type=word
TIMIT_path=
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/timit/timit_kaldi_standard_split.py \
--manifest_prefix="data/manifest" \
--src="data/local" \
if [ $? -ne 0 ]; then
echo "Prepare TIMIT failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type ${unit_type} \
--count_threshold=0 \
--vocab_path="data/vocab.txt" \
--manifest_paths="data/manifest.train.raw"
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=-1 \
--specgram_type="fbank" \
--feat_dim=80 \
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "TIMIT Data preparation done."
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
if [ $? -ne 0 ]; then
echo "Failed in export!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0
#!/usr/bin/env bash
# Copyright 2013 (Authors: Bagher BabaAli, Daniel Povey, Arnab Ghoshal)
# 2014 Brno University of Technology (Author: Karel Vesely)
# Apache 2.0.
if [ $# -ne 1 ]; then
echo "Argument should be the Timit directory, see ../run.sh for example."
exit 1;
fi
dir=`pwd`/data/local
mkdir -p $dir
local=`pwd`/local
utils=`pwd`/utils
conf=`pwd`/conf
function error_exit () {
echo -e "$@" >&2; exit 1;
}
PROG=$(basename $0)
[ -f $conf/test_spk.list ] || error_exit "$PROG line $LINENO: Eval-set speaker list not found.";
[ -f $conf/dev_spk.list ] || error_exit "$PROG line $LINENO: dev-set speaker list not found.";
# First check if the train & test directories exist (these can either be upper-
# or lower-cased)
if [ ! -d $*/TRAIN -o ! -d $*/TEST ] && [ ! -d $*/train -o ! -d $*/test ]; then
echo "timit_data_prep.sh: Spot check of command line argument failed"
echo "Command line argument must be absolute pathname to TIMIT directory"
echo "with name like /export/corpora5/LDC/LDC93S1/timit/TIMIT"
exit 1;
fi
# Now check what case the directory structure is
uppercased=false
train_dir=train
test_dir=test
if [ -d $*/TRAIN ]; then
uppercased=true
train_dir=TRAIN
test_dir=TEST
fi
tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
trap 'rm -rf "$tmpdir"' EXIT
# Get the list of speakers. The list of speakers in the 24-speaker core test
# set and the 50-speaker development set must be supplied to the script. All
# speakers in the 'train' directory are used for training.
if $uppercased; then
tr '[:lower:]' '[:upper:]' < $conf/dev_spk.list > $tmpdir/dev_spk
tr '[:lower:]' '[:upper:]' < $conf/test_spk.list > $tmpdir/test_spk
ls -d "$*"/TRAIN/DR*/* | sed -e "s:^.*/::" > $tmpdir/train_spk
else
tr '[:upper:]' '[:lower:]' < $conf/dev_spk.list > $tmpdir/dev_spk
tr '[:upper:]' '[:lower:]' < $conf/test_spk.list > $tmpdir/test_spk
ls -d "$*"/train/dr*/* | sed -e "s:^.*/::" > $tmpdir/train_spk
fi
cd $dir
for x in train dev test; do
# First, find the list of audio files (use only si & sx utterances).
# Note: train & test sets are under different directories, but doing find on
# both and grepping for the speakers will work correctly.
find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.WAV' \
| grep -f $tmpdir/${x}_spk > ${x}_sph.flist
sed -e 's:.*/\(.*\)/\(.*\).\(WAV\|wav\)$:\1_\2:' ${x}_sph.flist \
> $tmpdir/${x}_sph.uttids
paste $tmpdir/${x}_sph.uttids ${x}_sph.flist \
| sort -k1,1 > ${x}_sph.scp
cat ${x}_sph.scp | awk '{print $1}' > ${x}.uttids
# Now, convert the transcripts into our format (no normalization yet)
# Get the transcripts: each line of the output contains an utterance
# ID followed by the transcript.
find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.PHN' \
| grep -f $tmpdir/${x}_spk > $tmpdir/${x}_phn.flist
sed -e 's:.*/\(.*\)/\(.*\).\(PHN\|phn\)$:\1_\2:' $tmpdir/${x}_phn.flist \
> $tmpdir/${x}_phn.uttids
while read line; do
[ -f $line ] || error_exit "Cannot find transcription file '$line'";
cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;'
done < $tmpdir/${x}_phn.flist > $tmpdir/${x}_phn.trans
paste $tmpdir/${x}_phn.uttids $tmpdir/${x}_phn.trans \
| sort -k1,1 > ${x}.trans
# Do normalization steps.
cat ${x}.trans | $local/timit_norm_trans.pl -i - -m $conf/phones.60-48-39.map -to 39 | sort > $x.text || exit 1;
done
echo "Data preparation succeeded"
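As used from run.sh further below, the script takes a single argument, the TIMIT corpus root, e.g. `bash local/timit_data_prep.sh /export/corpora5/LDC/LDC93S1/timit/TIMIT` (path illustrative).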
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2012 Arnab Ghoshal
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script normalizes the TIMIT phonetic transcripts that have been
# extracted in a format where each line contains an utterance ID followed by
# the transcript, e.g.:
# fcke0_si1111 h# hh ah dx ux w iy dcl d ix f ay n ih q h#
my $usage = "Usage: timit_norm_trans.pl -i transcript -m phone_map -from [60|48] -to [48|39] > normalized\n
Normalizes phonetic transcriptions for TIMIT, by mapping the phones to a
smaller set defined by the -m option. This script assumes that the mapping is
done in the \"standard\" fashion, i.e. to 48 or 39 phones. The input is
assumed to have 60 phones (+1 for glottal stop, which is deleted), but that can
be changed using the -from option. The input format is assumed to be utterance
ID followed by transcript on the same line.\n";
use strict;
use Getopt::Long;
die "$usage" unless(@ARGV >= 1);
my ($in_trans, $phone_map, $num_phones_out);
my $num_phones_in = 60;
GetOptions ("i=s" => \$in_trans, # Input transcription
"m=s" => \$phone_map, # File containing phone mappings
"from=i" => \$num_phones_in, # Input #phones: must be 60 or 48
"to=i" => \$num_phones_out ); # Output #phones: must be 48 or 39
die $usage unless(defined($in_trans) && defined($phone_map) &&
defined($num_phones_out));
if ($num_phones_in != 60 && $num_phones_in != 48) {
die "Can only used 60 or 48 for -from (used $num_phones_in)."
}
if ($num_phones_out != 48 && $num_phones_out != 39) {
die "Can only used 48 or 39 for -to (used $num_phones_out)."
}
unless ($num_phones_out < $num_phones_in) {
die "Argument to -from ($num_phones_in) must be greater than that to -to ($num_phones_out)."
}
open(M, "<$phone_map") or die "Cannot open mappings file '$phone_map': $!";
my (%phonemap, %seen_phones);
my $num_seen_phones = 0;
while (<M>) {
chomp;
next if ($_ =~ /^q\s*.*$/); # Ignore glottal stops.
m:^(\S+)\s+(\S+)\s+(\S+)$: or die "Bad line: $_";
my $mapped_from = ($num_phones_in == 60)? $1 : $2;
my $mapped_to = ($num_phones_out == 48)? $2 : $3;
if (!defined($seen_phones{$mapped_to})) {
$seen_phones{$mapped_to} = 1;
$num_seen_phones += 1;
}
$phonemap{$mapped_from} = $mapped_to;
}
if ($num_seen_phones != $num_phones_out) {
die "Trying to map to $num_phones_out phones, but seen only $num_seen_phones";
}
open(T, "<$in_trans") or die "Cannot open transcription file '$in_trans': $!";
while (<T>) {
chomp;
$_ =~ m:^(\S+)\s+(.+): or die "Bad line: $_";
my $utt_id = $1;
my $trans = $2;
$trans =~ s/q//g; # Remove glottal stops.
$trans =~ s/^\s*//; $trans =~ s/\s*$//; # Normalize spaces
print $utt_id;
for my $phone (split(/\s+/, $trans)) {
if(exists $phonemap{$phone}) { print " $phonemap{$phone}"; }
if(not exists $phonemap{$phone}) { print " $phone"; }
}
print "\n";
}
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=50
conf_path=conf/transformer.yaml
avg_num=10
TIMIT_path= #path of TIMIT (Required, e.g. /export/corpora5/LDC/LDC93S1/timit/TIMIT)
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/timit_data_prep.sh ${TIMIT_path}
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
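A typical invocation of the recipe above: parse_options.sh maps `--name value` onto the shell variable of the same name, so the required TIMIT path can be passed as a flag, e.g. `bash run.sh --TIMIT_path /export/corpora5/LDC/LDC93S1/timit/TIMIT` (path illustrative).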
......@@ -2,32 +2,38 @@
data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
mean_std_filepath: data/mean_std.json
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 4
test_manifest: data/manifest.tiny
min_input_len: 0.0
max_input_len: 27.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
target_sample_rate: 16000
max_freq: None
n_fft: None
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
delta_delta: False
dither: 1.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
random_seed: 0
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
num_workers: 2
batch_size: 4
model:
num_conv_layers: 2
......@@ -37,12 +43,16 @@ model:
share_rnn_weights: True
training:
n_epoch: 20
n_epoch: 10
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 1
checkpoint:
kbest_n: 3
latest_n: 2
decoding:
batch_size: 128
......
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
min_input_len: 0.0
max_input_len: 27.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
batch_size: 4
model:
num_conv_layers: 2
num_rnn_layers: 4
rnn_layer_size: 2048
rnn_direction: forward
num_fc_layers: 2
fc_layers_size_list: 512, 256
use_gru: True
training:
n_epoch: 10
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 1
checkpoint:
kbest_n: 3
latest_n: 2
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
#! /usr/bin/env bash
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
if [ $# != 4 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
exit -1
fi
......@@ -11,9 +11,10 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
model_type=$4
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in export!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
......@@ -9,11 +9,12 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
......@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
......@@ -10,9 +10,10 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
model_type=$3
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 1 ];then
echo "usage: tune ckpt_path"
......
export MAIN_ROOT=${PWD}/../../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
......@@ -7,11 +7,12 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')  # ckpt = deepspeech2
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
......@@ -21,20 +22,20 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: ""
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -91,6 +93,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:
......
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: ""
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json
batch_size: 4
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -84,6 +86,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:
......
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: ""
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -87,6 +89,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:
......
......@@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: ""
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json
batch_size: 4
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -33,7 +35,6 @@ data:
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
......@@ -70,7 +71,7 @@ model:
training:
n_epoch: 20
n_epoch: 21
accum_grad: 1
global_grad_clip: 5.0
optim: adam
......@@ -82,10 +83,13 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:
batch_size: 64
batch_size: 8 #64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# alignment results are dumped into `result_file`
# .tier and .TextGrid files are dumped into the dir of `result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
../../s0/local/download_lm_en.sh
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
......@@ -13,7 +13,7 @@ ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
......@@ -9,29 +9,60 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
......@@ -12,7 +12,7 @@ config_path=$1
ckpt_name=$2
device=gpu
if [ ngpu == 0 ];then
if [ ${ngpu} == 0 ];then
device=cpu
fi
......
export MAIN_ROOT=${PWD}/../../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
......@@ -20,12 +20,12 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......@@ -34,6 +34,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
......@@ -9,14 +9,21 @@ if [ $(id -u) -eq 0 ]; then
fi
if [ -e /etc/lsb-release ];then
#${SUDO} apt-get update
${SUDO} apt-get install -y vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
${SUDO} apt-get update -y
${SUDO} apt-get install -y jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
if [ $? != 0 ]; then
error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user."
exit -1
fi
fi
# tools/make
rm tools/*.done
pushd tools && make && popd
source tools/venv/bin/activate
# install python dependencies
if [ -f "requirements.txt" ]; then
pip3 install -r requirements.txt
......@@ -43,6 +50,22 @@ if [ $? != 0 ]; then
rm libsndfile-1.0.28.tar.gz
fi
#install auto-log
python -c "import auto_log"
if [ $? != 0 ]; then
info_msg "Install auto_log into default system path"
test -d AutoLog || git clone https://github.com/LDOUBLEV/AutoLog
if [ $? != 0 ]; then
error_msg "Download auto_log failed !!!"
exit 1
fi
cd AutoLog
pip install -r requirements.txt
python setup.py install
cd ..
rm -rf AutoLog
fi
# install decoders
python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"
if [ $? != 0 ]; then
......@@ -66,4 +89,5 @@ if [ $? != 0 ]; then
fi
popd
info_msg "Install all dependencies successfully."
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(deepspeech VERSION 0.1)
set(CMAKE_VERBOSE_MAKEFILE on)
# set std-14
set(CMAKE_CXX_STANDARD 14)
# include file
include(FetchContent)
include(ExternalProject)
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# Option Configurations
###############################################################################
# option configurations
option(TEST_DEBUG "option for debug" OFF)
###############################################################################
# Include third party
###############################################################################
# #example for include third party
# FetchContent_Declare()
# # FetchContent_MakeAvailable was not added until CMake 3.14
# FetchContent_MakeAvailable()
# include_directories()
# ABSEIL-CPP
include(FetchContent)
FetchContent_Declare(
absl
GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
GIT_TAG "20210324.1"
)
FetchContent_MakeAvailable(absl)
# libsndfile
include(FetchContent)
FetchContent_Declare(
libsndfile
GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
GIT_TAG "1.0.31"
)
FetchContent_MakeAvailable(libsndfile)
###############################################################################
# Add local library
###############################################################################
# system lib, e.g.:
# find_package(<package>)
# if the dir has a CMakeLists.txt:
# add_subdirectory(<dir>)
# if the dir does not have a CMakeLists.txt:
# add_library(lib_name STATIC file.cc)
# target_link_libraries(lib_name item0 item1)
# add_dependencies(lib_name depend-target)
###############################################################################
# Library installation
###############################################################################
# install(<targets>)
###############################################################################
# Build binary file
###############################################################################
# add_executable(<target> <sources>)
# target_link_libraries(<target> <libs>)
aux_source_directory(. DIR_LIB_SRCS)
add_library(decoder STATIC ${DIR_LIB_SRCS})
......@@ -16,7 +16,7 @@ import unittest
import numpy as np
import paddle
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2 import DeepSpeech2Model
class TestDeepSpeech2Model(unittest.TestCase):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
class TestDeepSpeech2ModelOnline(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
self.batch_size = 2
self.feat_dim = 161
max_len = 210
# (B, T, D)
audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
audio_len = np.random.randint(max_len, size=self.batch_size)
audio_len[-1] = max_len
# (B, U)
text = np.array([[1, 2], [1, 2]])
text_len = np.array([2] * self.batch_size)
self.audio = paddle.to_tensor(audio, dtype='float32')
self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
self.text = paddle.to_tensor(text, dtype='int32')
self.text_len = paddle.to_tensor(text_len, dtype='int64')
def test_ds2_1(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_2(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_3(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_4(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_5(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_6(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
rnn_direction='bidirect',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_7(self):
use_gru = False
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=1,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=use_gru)
model.eval()
paddle.device.set_device("cpu")
de_ch_size = 8
eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
self.audio, self.audio_len)
eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
self.audio, self.audio_len, de_ch_size)
eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if not use_gru:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
def test_ds2_8(self):
use_gru = True
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=1,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=use_gru)
model.eval()
paddle.device.set_device("cpu")
de_ch_size = 8
eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
self.audio, self.audio_len)
eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
self.audio, self.audio_len, de_ch_size)
eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if not use_gru:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
if __name__ == '__main__':
unittest.main()
"""
Module containing all the spectrogram classes
"""
# 0.2.0
import torch
import torch.nn as nn
from torch.nn.functional import conv1d, conv2d, fold
import scipy # used only in CFP
import numpy as np
from time import time
# from nnAudio.librosa_functions import * # For debug purpose
# from nnAudio.utils import *
from .librosa_functions import *
from .utils import *
sz_float = 4 # size of a float
epsilon = 10e-8 # fudge factor for normalization
### --------------------------- Spectrogram Classes ---------------------------###
class STFT(torch.nn.Module):
"""This function is to calculate the short-time Fourier transform (STFT) of the input signal.
The input signal should be in one of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
n_fft : int
Size of Fourier transform. Default value is 2048.
win_length : int
the size of window frame and STFT filter.
Default: None (treated as equal to n_fft)
freq_bins : int
Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins.
hop_length : int
The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
window : str
The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
freq_scale : 'linear', 'log', or 'no'
Determine the spacing between each frequency bin. When `linear` or `log` is used,
the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bin will
start at 0Hz and end at Nyquist frequency with linear spacing.
    center : bool
        Putting the STFT kernel at the center of the time-step or not. If ``False``, the time
        index is the beginning of the STFT kernel, if ``True``, the time index is the center of
        the STFT kernel. Default value is ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
iSTFT : bool
To activate the iSTFT module or not. By default, it is False to save GPU memory.
Note: The iSTFT kernel is not trainable. If you want
a trainable iSTFT, use the iSTFT module.
fmin : int
The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument
does nothing.
fmax : int
The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument
does nothing.
    sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
        Setting the correct sampling rate is very important for calculating the correct frequency.
    trainable : bool
        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for STFT
        kernels will also be calculated and the STFT kernels will be updated during model training.
        Default value is ``False``
output_format : str
Control the spectrogram output type, either ``Magnitude``, ``Complex``, or ``Phase``.
The output_format can also be changed during the ``forward`` method.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
``shape = (num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
``shape = (num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
>>> spec_layer = Spectrogram.STFT()
>>> specs = spec_layer(x)
"""
def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',
freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,
fmin=50, fmax=6000, sr=22050, trainable=False,
output_format="Complex", verbose=True):
super().__init__()
# Trying to make the default setting same as librosa
        if win_length is None: win_length = n_fft
        if hop_length is None: hop_length = int(win_length // 4)
        self.output_format = output_format
        self.trainable = trainable
        self.stride = hop_length
        self.center = center
        self.pad_mode = pad_mode
        self.n_fft = n_fft
        self.freq_bins = freq_bins
        self.pad_amount = self.n_fft // 2
        self.window = window
        self.win_length = win_length
        self.iSTFT = iSTFT
start = time()
# Create filter windows for stft
kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,
win_length=win_length,
freq_bins=freq_bins,
window=window,
freq_scale=freq_scale,
fmin=fmin,
fmax=fmax,
sr=sr,
verbose=verbose)
kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)
kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)
# In this way, the inverse kernel and the forward kernel do not share the same memory...
kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)
kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)
if iSTFT:
self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))
self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))
# Making all these variables nn.Parameter, so that the model can be used with nn.Parallel
# self.kernel_sin = torch.nn.Parameter(self.kernel_sin, requires_grad=self.trainable)
# self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable)
# Applying window functions to the Fourier kernels
if window:
window_mask = torch.tensor(window_mask)
wsin = kernel_sin * window_mask
wcos = kernel_cos * window_mask
else:
wsin = kernel_sin
wcos = kernel_cos
        if self.trainable:
            wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)
            wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable)
            self.register_parameter('wsin', wsin)
            self.register_parameter('wcos', wcos)
        else:
            self.register_buffer('wsin', wsin)
            self.register_buffer('wcos', wcos)
# Prepare the shape of window mask so that it can be used later in inverse
self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))
        if verbose:
            print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
def forward(self, x, output_format=None):
"""
Convert a batch of waveforms to spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
output_format : str
Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.
Default value is ``Complex``.
"""
output_format = output_format or self.output_format
self.num_samples = x.shape[-1]
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.pad_amount, 0)
elif self.pad_mode == 'reflect':
if self.num_samples < self.pad_amount:
raise AssertionError("Signal length shorter than reflect padding length (n_fft // 2).")
padding = nn.ReflectionPad1d(self.pad_amount)
x = padding(x)
spec_imag = conv1d(x, self.wsin, stride=self.stride)
spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d
# remove redundant parts
spec_real = spec_real[:, :self.freq_bins, :]
spec_imag = spec_imag[:, :self.freq_bins, :]
if output_format=='Magnitude':
spec = spec_real.pow(2) + spec_imag.pow(2)
            if self.trainable:
return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0
else:
return torch.sqrt(spec)
elif output_format=='Complex':
return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part
elif output_format=='Phase':
return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase
def inverse(self, X, onesided=True, length=None, refresh_win=True):
"""
This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class,
which is to convert spectrograms back to waveforms.
It only works for the complex value spectrograms. If you have the magnitude spectrograms,
please use :func:`~nnAudio.Spectrogram.Griffin_Lim`.
Parameters
----------
onesided : bool
If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,
else use ``onesided=False``
length : int
To make sure the inverse STFT has the same output length of the original waveform, please
set `length` as your intended waveform length. By default, ``length=None``,
which will remove ``n_fft//2`` samples from the start and the end of the output.
refresh_win : bool
Recalculating the window sum square. If you have an input with fixed number of timesteps,
you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``
"""
        if not hasattr(self, 'kernel_sin_inv') or not hasattr(self, 'kernel_cos_inv'):
            raise NameError("Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`")
        assert X.dim()==4, "Inverse STFT only works for complex numbers; " \
                           "make sure your tensor is in the shape of (batch, freq_bins, timesteps, 2)." \
                           "\nIf you have a magnitude spectrogram, please consider using Griffin-Lim."
if onesided:
X = extend_fbins(X) # extend freq
X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]
# broadcast dimensions to support 2D convolution
X_real_bc = X_real.unsqueeze(1)
X_imag_bc = X_imag.unsqueeze(1)
a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))
b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))
# compute real and imag part. signal lies in the real part
real = a1 - b2
real = real.squeeze(-2)*self.window_mask
# Normalize the amplitude with n_fft
real /= (self.n_fft)
# Overlap and Add algorithm to connect all the frames
real = overlap_add(real, self.stride)
        # Prepare the window sum-square for division.
        # It only needs to be created once to save time,
        # unless the input spectrograms have different numbers of time steps.
        if not hasattr(self, 'w_sum') or refresh_win:
            self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()
            self.nonzero_indices = (self.w_sum>1e-10)
real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])
# Remove padding
if length is None:
if self.center:
real = real[:, self.pad_amount:-self.pad_amount]
else:
if self.center:
real = real[:, self.pad_amount:self.pad_amount + length]
else:
real = real[:, :length]
return real
def extra_repr(self) -> str:
return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(
self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable
)
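# --- Illustrative usage sketch (added for clarity; not part of the original nnAudio source) ---
# A minimal round trip through STFT and its inverse on a random one-second clip,
# assuming the default hyperparameters are acceptable. `iSTFT=True` is required
# so that `inverse` has its reconstruction kernels registered.
def _demo_stft():
    stft = STFT(n_fft=1024, hop_length=256, iSTFT=True,
                output_format='Complex', verbose=False)
    x = torch.randn(1, 22050)                       # (num_audio, len_audio)
    spec = stft(x)                                  # (1, freq_bins, time_steps, 2)
    x_hat = stft.inverse(spec, length=x.shape[-1])  # reconstruct to input length
    return spec.shape, x_hat.shape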
class MelSpectrogram(torch.nn.Module):
"""This function is to calculate the Melspectrogram of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
The sampling rate for the input audio.
It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
n_fft : int
The window size for the STFT. Default value is 2048
win_length : int
the size of window frame and STFT filter.
Default: None (treated as equal to n_fft)
n_mels : int
The number of Mel filter banks. The filter banks maps the n_fft to mel bins.
Default value is 128.
hop_length : int
The hop (or stride) size. Default value is 512.
window : str
The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
    center : bool
        Putting the STFT kernel at the center of the time-step or not. If ``False``,
        the time index is the beginning of the STFT kernel, if ``True``, the time index is the
        center of the STFT kernel. Default value is ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
htk : bool
When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the
Mel scale is logarithmic. The default value is ``False``.
fmin : int
The starting frequency for the lowest Mel filter bank.
fmax : int
The ending frequency for the highest Mel filter bank.
norm :
if 1, divide the triangular mel weights by the width of the mel band
(area normalization, AKA 'slaney' default in librosa).
Otherwise, leave all the triangles aiming for
a peak value of 1.0
trainable_mel : bool
Determine if the Mel filter banks are trainable or not. If ``True``, the gradients for Mel
filter banks will also be calculated and the Mel filter banks will be updated during model
training. Default value is ``False``.
    trainable_STFT : bool
        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for STFT
        kernels will also be calculated and the STFT kernels will be updated during model training.
        Default value is ``False``.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints.
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)``.
Examples
--------
>>> spec_layer = Spectrogram.MelSpectrogram()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=22050, n_fft=2048, win_length=None, n_mels=128, hop_length=512,
window='hann', center=True, pad_mode='reflect', power=2.0, htk=False,
fmin=0.0, fmax=None, norm=1, trainable_mel=False, trainable_STFT=False,
verbose=True, **kwargs):
super().__init__()
self.stride = hop_length
self.center = center
self.pad_mode = pad_mode
self.n_fft = n_fft
self.power = power
self.trainable_mel = trainable_mel
self.trainable_STFT = trainable_STFT
# Preparing for the stft layer. No need for center
self.stft = STFT(n_fft=n_fft, win_length=win_length, freq_bins=None,
hop_length=hop_length, window=window, freq_scale='no',
center=center, pad_mode=pad_mode, sr=sr, trainable=trainable_STFT,
output_format="Magnitude", verbose=verbose, **kwargs)
        # Creating kernel for mel spectrogram.
        # The STFT kernels are created (and timed) inside the STFT layer above,
        # so only the Mel filter creation is timed here.
        start = time()
        mel_basis = mel(sr, n_fft, n_mels, fmin, fmax, htk=htk, norm=norm)
        mel_basis = torch.tensor(mel_basis)
        if verbose:
            print("Mel filter created, time used = {:.4f} seconds".format(time()-start))
if trainable_mel:
# Making everything nn.Parameter, so that this model can support nn.DataParallel
mel_basis = torch.nn.Parameter(mel_basis, requires_grad=trainable_mel)
self.register_parameter('mel_basis', mel_basis)
else:
self.register_buffer('mel_basis', mel_basis)
# if trainable_mel==True:
# self.mel_basis = torch.nn.Parameter(self.mel_basis)
# if trainable_STFT==True:
# self.wsin = torch.nn.Parameter(self.wsin)
# self.wcos = torch.nn.Parameter(self.wcos)
def forward(self, x):
"""
Convert a batch of waveforms to Mel spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
x = broadcast_dim(x)
spec = self.stft(x, output_format='Magnitude')**self.power
melspec = torch.matmul(self.mel_basis, spec)
return melspec
    def extra_repr(self) -> str:
        return 'Mel filter banks size = {}, trainable_mel={}, trainable_STFT={}'.format(
            (*self.mel_basis.shape,), self.trainable_mel, self.trainable_STFT
        )
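# --- Illustrative usage sketch (added for clarity; not part of the original nnAudio source) ---
# A minimal sketch of MelSpectrogram on a small random batch; like any
# nn.Module, the layer can be moved to GPU with `.to('cuda')` if available.
def _demo_melspectrogram():
    spec_layer = MelSpectrogram(sr=22050, n_fft=2048, n_mels=128,
                                hop_length=512, verbose=False)
    x = torch.randn(2, 22050)   # a batch of two one-second clips
    mel_spec = spec_layer(x)    # (2, n_mels, time_steps)
    return mel_spec.shape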
class MFCC(torch.nn.Module):
"""This function is to calculate the Mel-frequency cepstral coefficients (MFCCs) of the input signal.
    This algorithm first extracts Mel spectrograms from the audio clips,
    then the discrete cosine transform is calculated to obtain the final MFCCs.
    Therefore, the Mel spectrogram part can be made trainable using
    ``trainable_mel`` and ``trainable_STFT``.
    It only supports type-II DCT at the moment. Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
    The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
n_mfcc : int
The number of Mel-frequency cepstral coefficients
    norm : string
        Normalization for the DCT basis. The default value is 'ortho'.
**kwargs
Other arguments for Melspectrogram such as n_fft, n_mels, hop_length, and window
Returns
-------
MFCCs : torch.tensor
It returns a tensor of MFCCs. shape = ``(num_samples, n_mfcc, time_steps)``.
Examples
--------
>>> spec_layer = Spectrogram.MFCC()
>>> mfcc = spec_layer(x)
"""
def __init__(self, sr=22050, n_mfcc=20, norm='ortho', verbose=True, ref=1.0, amin=1e-10, top_db=80.0, **kwargs):
super().__init__()
self.melspec_layer = MelSpectrogram(sr=sr, verbose=verbose, **kwargs)
# attributes that will be used for _power_to_db
if amin <= 0:
raise ParameterError('amin must be strictly positive')
amin = torch.tensor([amin])
ref = torch.abs(torch.tensor([ref]))
self.register_buffer('amin', amin)
self.register_buffer('ref', ref)
self.top_db = top_db
self.n_mfcc = n_mfcc
def _power_to_db(self, S):
'''
Refer to https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#power_to_db
        for the original implementation.
'''
log_spec = 10.0 * torch.log10(torch.max(S, self.amin))
log_spec -= 10.0 * torch.log10(torch.max(self.amin, self.ref))
if self.top_db is not None:
if self.top_db < 0:
raise ParameterError('top_db must be non-negative')
# make the dim same as log_spec so that it can be broadcasted
batch_wise_max = log_spec.flatten(1).max(1)[0].unsqueeze(1).unsqueeze(1)
log_spec = torch.max(log_spec, batch_wise_max - self.top_db)
return log_spec
def _dct(self, x, norm=None):
'''
        Refer to https://github.com/zh217/torch-dct for the original implementation.
'''
x = x.permute(0,2,1) # make freq the last axis, since dct applies to the frequency axis
x_shape = x.shape
N = x_shape[-1]
v = torch.cat([x[:, :, ::2], x[:, :, 1::2].flip([2])], dim=2)
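        # NOTE: torch.rfft follows the pre-1.8 PyTorch FFT API. On torch>=1.8 this
        # call would need to go through torch.fft instead (e.g.
        # torch.view_as_real(torch.fft.fft(v))) -- an assumption about newer
        # versions, not part of the original code.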
Vc = torch.rfft(v, 1, onesided=False)
# TODO: Can make the W_r and W_i trainable here
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)
V = Vc[:, :, :, 0] * W_r - Vc[:, :, :, 1] * W_i
if norm == 'ortho':
V[:, :, 0] /= np.sqrt(N) * 2
V[:, :, 1:] /= np.sqrt(N / 2) * 2
V = 2 * V
return V.permute(0,2,1) # swapping back the time axis and freq axis
def forward(self, x):
"""
Convert a batch of waveforms to MFCC.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
x = self.melspec_layer(x)
x = self._power_to_db(x)
        x = self._dct(x, norm='ortho')[:, :self.n_mfcc, :]
return x
def extra_repr(self) -> str:
return 'n_mfcc = {}'.format(
(self.n_mfcc)
)
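# --- Illustrative usage sketch (added for clarity; not part of the original nnAudio source) ---
# A minimal sketch of MFCC; `n_fft` and `n_mels` are forwarded to the
# underlying MelSpectrogram through **kwargs.
def _demo_mfcc():
    mfcc_layer = MFCC(sr=22050, n_mfcc=20, n_fft=2048, n_mels=128, verbose=False)
    x = torch.randn(1, 22050)
    mfccs = mfcc_layer(x)       # (1, n_mfcc, time_steps)
    return mfccs.shape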
class Gammatonegram(torch.nn.Module):
"""
    This function is to calculate the Gammatonegram of the input signal.
    Input signal should be in either of the following shapes.\n
    1. ``(len_audio)``\n
    2. ``(num_audio, len_audio)``\n
    3. ``(num_audio, 1, len_audio)``
    The correct shape will be inferred automatically if the input follows these 3 shapes.
    This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency.
n_fft : int
The window size for the STFT. Default value is 2048
    n_bins : int
        The number of Gammatone filter banks. The filter banks map the n_fft bins to Gammatone bins. Default value is 64.
hop_length : int
The hop (or stride) size. Default value is 512.
window : str
The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to scipy documentation for possible windowing functions. The default value is 'hann'
center : bool
        Putting the STFT kernel at the center of the time-step or not. If ``False``, the time index is the beginning of the STFT kernel, if ``True``, the time index is the center of the STFT kernel. Default value is ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
    htk : bool
        When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the Mel scale is logarithmic. The default value is ``False``. (This argument is kept for API compatibility; the Gammatone filter bank itself does not use it.)
fmin : int
The starting frequency for the lowest Gammatone filter bank
fmax : int
The ending frequency for the highest Gammatone filter bank
    trainable_bins : bool
        Determine if the Gammatone filter banks are trainable or not. If ``True``, the gradients for the Gammatone filter banks will also be calculated and the filter banks will be updated during model training. Default value is ``False``
    trainable_STFT : bool
        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for STFT kernels will also be calculated and the STFT kernels will be updated during model training. Default value is ``False``
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)``.
Examples
--------
>>> spec_layer = Spectrogram.Gammatonegram()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=44100, n_fft=2048, n_bins=64, hop_length=512, window='hann', center=True, pad_mode='reflect',
power=2.0, htk=False, fmin=20.0, fmax=None, norm=1, trainable_bins=False, trainable_STFT=False,
verbose=True):
super(Gammatonegram, self).__init__()
self.stride = hop_length
self.center = center
self.pad_mode = pad_mode
self.n_fft = n_fft
self.power = power
# Create filter windows for stft
start = time()
wsin, wcos, self.bins2freq, _, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window, freq_scale='no',
sr=sr)
wsin = torch.tensor(wsin, dtype=torch.float)
wcos = torch.tensor(wcos, dtype=torch.float)
if trainable_STFT:
wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
else:
self.register_buffer('wsin', wsin)
self.register_buffer('wcos', wcos)
        # Creating kernel for Gammatone spectrogram.
        # Report the STFT kernel timing before resetting the timer for the
        # Gammatone filter creation (previously both prints shared one timer).
        if verbose:
            print("STFT filter created, time used = {:.4f} seconds".format(time() - start))
        start = time()
        gammatone_basis = gammatone(sr, n_fft, n_bins, fmin, fmax)
        gammatone_basis = torch.tensor(gammatone_basis)
        if verbose:
            print("Gammatone filter created, time used = {:.4f} seconds".format(time() - start))
        # Making everything nn.Parameter, so that this model can support nn.DataParallel
if trainable_bins:
gammatone_basis = torch.nn.Parameter(gammatone_basis, requires_grad=trainable_bins)
self.register_parameter('gammatone_basis', gammatone_basis)
else:
self.register_buffer('gammatone_basis', gammatone_basis)
# if trainable_mel==True:
# self.mel_basis = torch.nn.Parameter(self.mel_basis)
# if trainable_STFT==True:
# self.wsin = torch.nn.Parameter(self.wsin)
# self.wcos = torch.nn.Parameter(self.wcos)
def forward(self, x):
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.n_fft // 2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.n_fft // 2)
x = padding(x)
spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2) \
+ conv1d(x, self.wcos, stride=self.stride).pow(2)) ** self.power # Doing STFT by using conv1d
gammatonespec = torch.matmul(self.gammatone_basis, spec)
return gammatonespec
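# --- Illustrative usage sketch (added for clarity; not part of the original nnAudio source) ---
# A minimal sketch of Gammatonegram with the default 64 Gammatone bins.
def _demo_gammatonegram():
    spec_layer = Gammatonegram(sr=44100, n_fft=2048, n_bins=64,
                               hop_length=512, verbose=False)
    x = torch.randn(1, 44100)   # one second of audio at 44.1kHz
    gamma_spec = spec_layer(x)  # (1, n_bins, time_steps)
    return gamma_spec.shape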
class CQT1992(torch.nn.Module):
"""
    This algorithm uses the method proposed in [1], which would run extremely slow if low frequencies (below 220Hz)
are included in the frequency bins.
Please refer to :func:`~nnAudio.Spectrogram.CQT1992v2` for a more
computational and memory efficient version.
[1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
This function is to calculate the CQT of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
    The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
hop_length : int
The hop (or stride) size. Default value is 512.
    fmin : float
        The frequency for the lowest CQT bin. The default for this class is 220Hz, since the
        1992 algorithm becomes extremely slow for lower frequencies (32.70Hz corresponds to the note C0).
    fmax : float
        The frequency for the highest CQT bin. Default is ``None``, therefore the highest CQT bin is
        inferred from ``n_bins`` and ``bins_per_octave``.
        If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored and ``n_bins``
        will be calculated automatically.
n_bins : int
The total numbers of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
bins_per_octave : int
Number of bins per octave. Default is 12.
trainable_STFT : bool
Determine if the time to frequency domain transformation kernel for the input audio is trainable or not.
Default is ``False``
trainable_CQT : bool
Determine if the frequency domain CQT kernel is trainable or not.
Default is ``False``
norm : int
Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
Default is ``1``, which is same as the normalization used in librosa.
window : str
The windowing function for CQT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
    center : bool
        Putting the CQT kernel at the center of the time-step or not. If ``False``, the time index is
        the beginning of the CQT kernel, if ``True``, the time index is the center of the CQT kernel.
        Default value is ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
    trainable : bool
        Determine if the CQT kernels are trainable or not. If ``True``, the gradients for CQT kernels
        will also be calculated and the CQT kernels will be updated during model training.
        Default value is ``False``.
output_format : str
Determine the return type.
``Magnitude`` will return the magnitude of the STFT result, shape = ``(num_samples, freq_bins,time_steps)``;
``Complex`` will return the STFT result in complex number, shape = ``(num_samples, freq_bins,time_steps, 2)``;
        ``Phase`` will return the phase of the STFT result, shape = ``(num_samples, freq_bins,time_steps, 2)``.
The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
shape = ``(num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
shape = ``(num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
    >>> spec_layer = Spectrogram.CQT1992()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84,
trainable_STFT=False, trainable_CQT=False, bins_per_octave=12, filter_scale=1,
output_format='Magnitude', norm=1, window='hann', center=True, pad_mode='reflect'):
super().__init__()
# norm arg is not functioning
self.hop_length = hop_length
self.center = center
self.pad_mode = pad_mode
self.norm = norm
self.output_format = output_format
# creating kernels for CQT
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
print("Creating CQT kernels ...", end='\r')
start = time()
cqt_kernels, self.kernel_width, lenghts, freqs = create_cqt_kernels(Q,
sr,
fmin,
n_bins,
bins_per_octave,
norm,
window,
fmax)
self.register_buffer('lenghts', lenghts)
self.frequencies = freqs
cqt_kernels = fft(cqt_kernels)[:,:self.kernel_width//2+1]
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
# creating kernels for stft
# self.cqt_kernels_real*=lenghts.unsqueeze(1)/self.kernel_width # Trying to normalize as librosa
# self.cqt_kernels_imag*=lenghts.unsqueeze(1)/self.kernel_width
print("Creating STFT kernels ...", end='\r')
start = time()
kernel_sin, kernel_cos, self.bins2freq, _, window = create_fourier_kernels(self.kernel_width,
window='ones',
freq_scale='no')
# Converting kernels from numpy arrays to torch tensors
wsin = torch.tensor(kernel_sin * window)
wcos = torch.tensor(kernel_cos * window)
cqt_kernels_real = torch.tensor(cqt_kernels.real.astype(np.float32))
cqt_kernels_imag = torch.tensor(cqt_kernels.imag.astype(np.float32))
if trainable_STFT:
wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
else:
self.register_buffer('wsin', wsin)
self.register_buffer('wcos', wcos)
if trainable_CQT:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable_CQT)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable_CQT)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
def forward(self, x, output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.kernel_width//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.kernel_width//2)
x = padding(x)
# STFT
fourier_real = conv1d(x, self.wcos, stride=self.hop_length)
fourier_imag = conv1d(x, self.wsin, stride=self.hop_length)
# CQT
CQT_real, CQT_imag = complex_mul((self.cqt_kernels_real, self.cqt_kernels_imag),
(fourier_real, fourier_imag))
CQT = torch.stack((CQT_real,-CQT_imag),-1)
if normalization_type == 'librosa':
CQT *= torch.sqrt(self.lenghts.view(-1,1,1))/self.kernel_width
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT *= 2/self.kernel_width
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
# if self.norm:
# CQT = CQT/self.kernel_width*torch.sqrt(self.lenghts.view(-1,1,1))
# else:
# CQT = CQT*torch.sqrt(self.lenghts.view(-1,1,1))
if output_format=='Magnitude':
# Getting CQT Amplitude
return torch.sqrt(CQT.pow(2).sum(-1))
elif output_format=='Complex':
return CQT
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT_imag,CQT_real))
phase_imag = torch.sin(torch.atan2(CQT_imag,CQT_real))
return torch.stack((phase_real,phase_imag), -1)
def extra_repr(self) -> str:
return 'STFT kernel size = {}, CQT kernel size = {}'.format(
(*self.wcos.shape,), (*self.cqt_kernels_real.shape,)
)
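# --- Illustrative usage sketch (added for clarity; not part of the original nnAudio source) ---
# A minimal sketch of CQT1992. The high default fmin=220Hz is deliberate:
# the 1992 algorithm builds one large kernel, so low fmin values are slow.
# 60 bins at 12 bins per octave keeps the top bin below the Nyquist frequency.
def _demo_cqt1992():
    cqt_layer = CQT1992(sr=22050, hop_length=512, fmin=220, n_bins=60,
                        bins_per_octave=12)
    x = torch.randn(1, 22050)
    cqt = cqt_layer(x)          # (1, n_bins, time_steps) for 'Magnitude'
    return cqt.shape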
class CQT2010(torch.nn.Module):
"""
    This algorithm uses the resampling method proposed in [1].
    Instead of convolving the STFT results with a gigantic CQT kernel covering the full frequency
    spectrum, we make a small CQT kernel covering only the top octave.
    Then we keep downsampling the input audio by a factor of 2 and convolving it with the
    small CQT kernel. Every time the input audio is downsampled, the CQT relative to the downsampled
    input is equivalent to the next lower octave.
    The kernel creation process is still the same as the 1992 algorithm. Therefore, we can reuse the code
    from the 1992 algorithm [2].
[1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010).
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
    The early downsampling factor is used to downsample the input audio so as to reduce the CQT kernel size.
    The results with and without early downsampling are more or less the same except in the very low
    frequency region where freq < 40Hz.
"""
def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, bins_per_octave=12,
norm=True, basis_norm=1, window='hann', pad_mode='reflect', trainable_STFT=False, filter_scale=1,
trainable_CQT=False, output_format='Magnitude', earlydownsample=True, verbose=True):
super().__init__()
self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft
# basis_norm is for normalizing basis
self.hop_length = hop_length
self.pad_mode = pad_mode
self.n_bins = n_bins
self.output_format = output_format
self.earlydownsample = earlydownsample # TODO: activate early downsampling later if possible
# This will be used to calculate filter_cutoff and creating CQT kernels
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
# Creating lowpass filter and make it a torch tensor
if verbose==True:
print("Creating low pass filter ...", end='\r')
start = time()
lowpass_filter = torch.tensor(create_lowpass_filter(
band_center = 0.5,
kernelLength=256,
transitionBandwidth=0.001
)
)
# Broadcast the tensor to the shape that fits conv1d
self.register_buffer('lowpass_filter', lowpass_filter[None,None,:])
if verbose==True:
print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))
        # Calculate the number of filters required for the kernel
        # n_octaves determines how many resamplings are required for the CQT
n_filters = min(bins_per_octave, n_bins)
self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
# print("n_octaves = ", self.n_octaves)
# Calculate the lowest frequency bin for the top octave kernel
self.fmin_t = fmin*2**(self.n_octaves-1)
remainder = n_bins % bins_per_octave
# print("remainder = ", remainder)
if remainder==0:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave)
else:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave)
            self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the minimum bin of the top octave
if fmax_t > sr/2:
raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
please reduce the n_bins'.format(fmax_t))
if self.earlydownsample == True: # Do early downsampling if this argument is True
if verbose==True:
print("Creating early downsampling filter ...", end='\r')
start = time()
sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
self.earlydownsample = get_early_downsample_params(sr,
hop_length,
fmax_t,
Q,
self.n_octaves,
verbose)
self.register_buffer('early_downsample_filter', early_downsample_filter)
if verbose==True:
print("Early downsampling filter created, \
time used = {:.4f} seconds".format(time()-start))
else:
self.downsample_factor=1.
# Preparing CQT kernels
if verbose==True:
print("Creating CQT kernels ...", end='\r')
start = time()
# print("Q = {}, fmin_t = {}, n_filters = {}".format(Q, self.fmin_t, n_filters))
basis, self.n_fft, _, _ = create_cqt_kernels(Q,
sr,
self.fmin_t,
n_filters,
bins_per_octave,
norm=basis_norm,
topbin_check=False)
# This is for the normalization in the end
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))  # np.float is a removed numpy alias
self.frequencies = freqs
lenghts = np.ceil(Q * sr / freqs)
lenghts = torch.tensor(lenghts).float()
self.register_buffer('lenghts', lenghts)
self.basis=basis
fft_basis = fft(basis)[:,:self.n_fft//2+1] # Convert CQT kenral from time domain to freq domain
# These cqt_kernel is already in the frequency domain
cqt_kernels_real = torch.tensor(fft_basis.real.astype(np.float32))
cqt_kernels_imag = torch.tensor(fft_basis.imag.astype(np.float32))
if verbose==True:
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
# print("Getting cqt kernel done, n_fft = ",self.n_fft)
# Preparing kernels for Short-Time Fourier Transform (STFT)
# We set the frequency range in the CQT filter instead of here.
if verbose==True:
print("Creating STFT kernels ...", end='\r')
start = time()
kernel_sin, kernel_cos, self.bins2freq, _, window = create_fourier_kernels(self.n_fft, window='ones', freq_scale='no')
wsin = kernel_sin * window
wcos = kernel_cos * window
wsin = torch.tensor(wsin)
wcos = torch.tensor(wcos)
if verbose==True:
print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
if trainable_STFT:
wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
else:
self.register_buffer('wsin', wsin)
self.register_buffer('wcos', wcos)
if trainable_CQT:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable_CQT)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable_CQT)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
# If center==True, the STFT window will be put in the middle, and paddings at the beginning
# and ending are required.
if self.pad_mode == 'constant':
self.padding = nn.ConstantPad1d(self.n_fft//2, 0)
elif self.pad_mode == 'reflect':
self.padding = nn.ReflectionPad1d(self.n_fft//2)
def forward(self,x, output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.earlydownsample==True:
x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
hop = self.hop_length
CQT = get_cqt_complex2(x, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding,
wcos=self.wcos, wsin=self.wsin)
x_down = x # Preparing a new variable for downsampling
for i in range(self.n_octaves-1):
hop = hop//2
x_down = downsampling_by_2(x_down, self.lowpass_filter)
CQT1 = get_cqt_complex2(x_down, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding,
wcos=self.wcos, wsin=self.wsin)
CQT = torch.cat((CQT1, CQT),1)
CQT = CQT[:,-self.n_bins:,:] # Removing unwanted top bins
if normalization_type == 'librosa':
CQT *= torch.sqrt(self.lenghts.view(-1,1,1))/self.n_fft
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT *= 2/self.n_fft
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
if output_format=='Magnitude':
# Getting CQT Amplitude
return torch.sqrt(CQT.pow(2).sum(-1))
elif output_format=='Complex':
return CQT
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
phase_imag = torch.sin(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
return torch.stack((phase_real,phase_imag), -1)
def extra_repr(self) -> str:
return 'STFT kernel size = {}, CQT kernel size = {}'.format(
(*self.wcos.shape,), (*self.cqt_kernels_real.shape,)
)
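# --- Illustrative usage sketch (added for clarity; not part of the original nnAudio source) ---
# A minimal sketch of CQT2010, which handles a low fmin efficiently by
# processing one octave at a time on progressively downsampled audio.
def _demo_cqt2010():
    cqt_layer = CQT2010(sr=22050, hop_length=512, fmin=32.70, n_bins=84,
                        bins_per_octave=12, verbose=False)
    x = torch.randn(1, 22050)
    cqt = cqt_layer(x)          # (1, n_bins, time_steps) for 'Magnitude'
    return cqt.shape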
class CQT1992v2(torch.nn.Module):
"""This function is to calculate the CQT of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
    The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
    This algorithm uses the method proposed in [1]. I slightly modified it so that it runs faster
    than the original 1992 algorithm; that is why I call it version 2.
[1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
Parameters
----------
sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
hop_length : int
The hop (or stride) size. Default value is 512.
fmin : float
        The frequency for the lowest CQT bin. Default is 32.70Hz, which corresponds to the note C0.
fmax : float
        The frequency for the highest CQT bin. Default is ``None``, therefore the highest CQT bin is
        inferred from ``n_bins`` and ``bins_per_octave``.
        If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored and ``n_bins``
        will be calculated automatically.
n_bins : int
The total numbers of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
bins_per_octave : int
Number of bins per octave. Default is 12.
filter_scale : float > 0
        Filter scale factor. Values of filter_scale smaller than 1 can be used to improve the time resolution at the
        cost of degrading the frequency resolution. Note that setting, for example, filter_scale = 0.5 and
        bins_per_octave = 48 leads to exactly the same time-frequency resolution trade-off as setting filter_scale = 1
        and bins_per_octave = 24, but the former contains twice as many frequency bins per octave. In this sense,
        filter_scale < 1 can be seen as oversampling the frequency axis, analogously to the use of zero
        padding when calculating the DFT.
norm : int
Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
Default is ``1``, which is same as the normalization used in librosa.
    window : string, float, or tuple
        The windowing function for CQT. If it is a string, it uses ``scipy.signal.get_window``. If it is a
        tuple, only the gaussian window guarantees a constant Q factor. A gaussian window should be given as a
        tuple ('gaussian', att) where att is the attenuation at the border, given in dB.
Please refer to scipy documentation for possible windowing functions. The default value is 'hann'.
    center : bool
        Putting the CQT kernel at the center of the time-step or not. If ``False``, the time index is
        the beginning of the CQT kernel, if ``True``, the time index is the center of the CQT kernel.
        Default value is ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
trainable : bool
        Determine if the CQT kernels are trainable or not. If ``True``, the gradients for CQT kernels
        will also be calculated and the CQT kernels will be updated during model training.
Default value is ``False``.
output_format : str
Determine the return type.
``Magnitude`` will return the magnitude of the STFT result, shape = ``(num_samples, freq_bins,time_steps)``;
``Complex`` will return the STFT result in complex number, shape = ``(num_samples, freq_bins,time_steps, 2)``;
        ``Phase`` will return the phase of the STFT result, shape = ``(num_samples, freq_bins,time_steps, 2)``.
The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
shape = ``(num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
shape = ``(num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
>>> spec_layer = Spectrogram.CQT1992v2()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84,
bins_per_octave=12, filter_scale=1, norm=1, window='hann', center=True, pad_mode='reflect',
trainable=False, output_format='Magnitude', verbose=True):
super().__init__()
self.trainable = trainable
self.hop_length = hop_length
self.center = center
self.pad_mode = pad_mode
self.output_format = output_format
# creating kernels for CQT
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
if verbose==True:
print("Creating CQT kernels ...", end='\r')
start = time()
cqt_kernels, self.kernel_width, lenghts, freqs = create_cqt_kernels(Q,
sr,
fmin,
n_bins,
bins_per_octave,
norm,
window,
fmax)
self.register_buffer('lenghts', lenghts)
self.frequencies = freqs
cqt_kernels_real = torch.tensor(cqt_kernels.real).unsqueeze(1)
cqt_kernels_imag = torch.tensor(cqt_kernels.imag).unsqueeze(1)
if trainable:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
if verbose==True:
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
def forward(self,x, output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
        normalization_type : str
            Type of the normalisation. The possible options are: \n
            'librosa' : the output fits the librosa one \n
            'convolutional' : the output conserves the convolutional inequalities of the wavelet transform:\n
            for all p ∈ [1, inf] \n
            - || CQT ||_p <= || f ||_p || g ||_1 \n
            - || CQT ||_p <= || f ||_1 || g ||_p \n
            - || CQT ||_2 = || f ||_2 || g ||_2 \n
            'wrap' : wraps positive and negative frequencies into positive frequencies. This means that the CQT of a
            sine (or cosine) with a constant amplitude equal to 1 will have the value 1 in the bin corresponding to
            its frequency.
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.kernel_width//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.kernel_width//2)
x = padding(x)
# CQT
CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length)
CQT_imag = -conv1d(x, self.cqt_kernels_imag, stride=self.hop_length)
if normalization_type == 'librosa':
CQT_real *= torch.sqrt(self.lenghts.view(-1, 1))
CQT_imag *= torch.sqrt(self.lenghts.view(-1, 1))
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT_real *= 2
CQT_imag *= 2
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
if output_format=='Magnitude':
if self.trainable==False:
# Getting CQT Amplitude
CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2))
else:
CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)+1e-8)
return CQT
elif output_format=='Complex':
return torch.stack((CQT_real,CQT_imag),-1)
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT_imag,CQT_real))
phase_imag = torch.sin(torch.atan2(CQT_imag,CQT_real))
return torch.stack((phase_real,phase_imag), -1)
def forward_manual(self,x):
"""
Method for debugging
"""
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.kernel_width//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.kernel_width//2)
x = padding(x)
# CQT
CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length)
CQT_imag = conv1d(x, self.cqt_kernels_imag, stride=self.hop_length)
# Getting CQT Amplitude
CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2))
return CQT*torch.sqrt(self.lenghts.view(-1,1))
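# --- Illustrative usage sketch (added for clarity; not part of the original nnAudio source) ---
# A minimal sketch comparing the normalization options of CQT1992v2:
# 'librosa' matches librosa's scaling, 'wrap' makes a unit-amplitude sinusoid
# reach 1.0 in its own bin, and 'convolutional' leaves the raw conv output.
def _demo_cqt1992v2():
    cqt_layer = CQT1992v2(sr=22050, fmin=32.70, n_bins=84, verbose=False)
    x = torch.randn(1, 22050)
    out_librosa = cqt_layer(x, normalization_type='librosa')
    out_wrap = cqt_layer(x, normalization_type='wrap')
    return out_librosa.shape, out_wrap.shape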
class CQT2010v2(torch.nn.Module):
"""This function is to calculate the CQT of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
    The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
    This algorithm uses the resampling method proposed in [1].
    Instead of convolving the STFT results with a gigantic CQT kernel covering the full frequency
    spectrum, we make a small CQT kernel covering only the top octave. Then we keep downsampling the
    input audio by a factor of 2 and convolving it with the small CQT kernel.
    Every time the input audio is downsampled, the CQT relative to the downsampled input is equivalent
    to the next lower octave.
    The kernel creation process is still the same as the 1992 algorithm. Therefore, we can reuse the
    code from the 1992 algorithm [2].
[1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010).
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
    The early downsampling factor is used to downsample the input audio so as to reduce the CQT kernel size.
    The results with and without early downsampling are more or less the same except in the very low
    frequency region where freq < 40Hz.
Parameters
----------
sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
hop_length : int
The hop (or stride) size. Default value is 512.
fmin : float
        The frequency for the lowest CQT bin. Default is 32.70Hz, which corresponds to the note C0.
    fmax : float
        The frequency for the highest CQT bin. Default is ``None``, therefore the highest CQT bin is
        inferred from ``n_bins`` and ``bins_per_octave``. If ``fmax`` is not ``None``, then the
        argument ``n_bins`` will be ignored and ``n_bins`` will be calculated automatically.
n_bins : int
The total numbers of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
bins_per_octave : int
Number of bins per octave. Default is 12.
norm : bool
Normalization for the CQT result.
basis_norm : int
Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
Default is ``1``, which is same as the normalization used in librosa.
window : str
The windowing function for CQT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'
pad_mode : str
The padding method. Default value is 'reflect'.
    trainable : bool
        Determine if the CQT kernels are trainable or not. If ``True``, the gradients for CQT kernels
        will also be calculated and the CQT kernels will be updated during model training.
        Default value is ``False``
output_format : str
Determine the return type.
'Magnitude' will return the magnitude of the STFT result, shape = ``(num_samples, freq_bins, time_steps)``;
'Complex' will return the STFT result in complex number, shape = ``(num_samples, freq_bins, time_steps, 2)``;
        'Phase' will return the phase of the STFT result, shape = ``(num_samples, freq_bins,time_steps, 2)``.
The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints.
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
shape = ``(num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
shape = ``(num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
>>> spec_layer = Spectrogram.CQT2010v2()
>>> specs = spec_layer(x)
"""
    # TODO: need to deal with the filter and other tensors
def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, filter_scale=1,
bins_per_octave=12, norm=True, basis_norm=1, window='hann', pad_mode='reflect',
earlydownsample=True, trainable=False, output_format='Magnitude', verbose=True):
super().__init__()
self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft
# basis_norm is for normalizing basis
self.hop_length = hop_length
self.pad_mode = pad_mode
self.n_bins = n_bins
self.earlydownsample = earlydownsample # We will activate early downsampling later if possible
self.trainable = trainable
self.output_format = output_format
# It will be used to calculate filter_cutoff and creating CQT kernels
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
# Creating lowpass filter and make it a torch tensor
if verbose==True:
print("Creating low pass filter ...", end='\r')
start = time()
# self.lowpass_filter = torch.tensor(
# create_lowpass_filter(
# band_center = 0.50,
# kernelLength=256,
# transitionBandwidth=0.001))
lowpass_filter = torch.tensor(create_lowpass_filter(
band_center = 0.50,
kernelLength=256,
transitionBandwidth=0.001)
)
# Broadcast the tensor to the shape that fits conv1d
self.register_buffer('lowpass_filter', lowpass_filter[None,None,:])
if verbose==True:
print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))
        # Calculate the number of filters required for the kernel
        # n_octaves determines how many resamplings are required for the CQT
n_filters = min(bins_per_octave, n_bins)
self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
if verbose==True:
print("num_octave = ", self.n_octaves)
# Calculate the lowest frequency bin for the top octave kernel
self.fmin_t = fmin*2**(self.n_octaves-1)
remainder = n_bins % bins_per_octave
# print("remainder = ", remainder)
if remainder==0:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave)
else:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave)
            self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the minimum bin of the top octave
if fmax_t > sr/2:
raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
please reduce the n_bins'.format(fmax_t))
if self.earlydownsample == True: # Do early downsampling if this argument is True
if verbose==True:
print("Creating early downsampling filter ...", end='\r')
start = time()
sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
self.earlydownsample = get_early_downsample_params(sr,
hop_length,
fmax_t,
Q,
self.n_octaves,
verbose)
self.register_buffer('early_downsample_filter', early_downsample_filter)
if verbose==True:
print("Early downsampling filter created, \
time used = {:.4f} seconds".format(time()-start))
else:
self.downsample_factor=1.
# Preparing CQT kernels
if verbose==True:
print("Creating CQT kernels ...", end='\r')
start = time()
basis, self.n_fft, lenghts, _ = create_cqt_kernels(Q,
sr,
self.fmin_t,
n_filters,
bins_per_octave,
norm=basis_norm,
topbin_check=False)
# For normalization in the end
# The freqs returned by create_cqt_kernels cannot be used
# Since that returns only the top octave bins
# We need the information for all freq bin
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))  # np.float is a removed numpy alias
self.frequencies = freqs
lenghts = np.ceil(Q * sr / freqs)
lenghts = torch.tensor(lenghts).float()
self.register_buffer('lenghts', lenghts)
self.basis = basis
# These cqt_kernel is already in the frequency domain
cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1)
cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)
if trainable:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
if verbose==True:
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
# print("Getting cqt kernel done, n_fft = ",self.n_fft)
# If center==True, the STFT window will be put in the middle, and paddings at the beginning
# and ending are required.
if self.pad_mode == 'constant':
self.padding = nn.ConstantPad1d(self.n_fft//2, 0)
elif self.pad_mode == 'reflect':
self.padding = nn.ReflectionPad1d(self.n_fft//2)
def forward(self,x,output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.earlydownsample==True:
x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
hop = self.hop_length
CQT = get_cqt_complex(x, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding) # Getting the top octave CQT
x_down = x # Preparing a new variable for downsampling
for i in range(self.n_octaves-1):
hop = hop//2
x_down = downsampling_by_2(x_down, self.lowpass_filter)
CQT1 = get_cqt_complex(x_down, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding)
CQT = torch.cat((CQT1, CQT),1)
CQT = CQT[:,-self.n_bins:,:] # Removing unwanted bottom bins
# print("downsample_factor = ",self.downsample_factor)
# print(CQT.shape)
# print(self.lenghts.view(-1,1).shape)
        # Normalizing the output with the downsampling factor; 2**(self.n_octaves-1)
        # makes the magnitude match the CQT1992 implementation
CQT = CQT*self.downsample_factor
# Normalize again to get same result as librosa
if normalization_type == 'librosa':
CQT = CQT*torch.sqrt(self.lenghts.view(-1,1,1))
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT *= 2
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
if output_format=='Magnitude':
if self.trainable==False:
# Getting CQT Amplitude
return torch.sqrt(CQT.pow(2).sum(-1))
else:
return torch.sqrt(CQT.pow(2).sum(-1)+1e-8)
elif output_format=='Complex':
return CQT
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
phase_imag = torch.sin(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
return torch.stack((phase_real,phase_imag), -1)
class CQT(CQT1992v2):
"""An abbreviation for :func:`~nnAudio.Spectrogram.CQT1992v2`. Please refer to the :func:`~nnAudio.Spectrogram.CQT1992v2` documentation"""
pass
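# A minimal usage sketch (not part of the original file; parameter values are
# illustrative and the output shape assumes the default output_format='Magnitude'):
#
# >>> import torch
# >>> x = torch.randn(16000)                    # 1 second of fake audio at 16 kHz
# >>> spec_layer = CQT(sr=16000, hop_length=512, fmin=32.7, n_bins=84)
# >>> cqt = spec_layer(x)                       # (1, n_bins, time_steps)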
# The section below is for development purposes
# Please don't use the following classes
#
class DFT(torch.nn.Module):
"""
    Experimental feature before `torch.fft` was made available.
    The inverse function only works for a single frame, i.e. input shape = (batch, n_fft, 1)
"""
def __init__(self, n_fft=2048, freq_bins=None, hop_length=512,
window='hann', freq_scale='no', center=True, pad_mode='reflect',
fmin=50, fmax=6000, sr=22050):
super().__init__()
self.stride = hop_length
self.center = center
self.pad_mode = pad_mode
self.n_fft = n_fft
# Create filter windows for stft
wsin, wcos, self.bins2freq = create_fourier_kernels(n_fft=n_fft,
freq_bins=n_fft,
window=window,
freq_scale=freq_scale,
fmin=fmin,
fmax=fmax,
sr=sr)
self.wsin = torch.tensor(wsin, dtype=torch.float)
self.wcos = torch.tensor(wcos, dtype=torch.float)
def forward(self,x):
"""
        Convert a batch of waveforms to spectra.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.n_fft//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.n_fft//2)
x = padding(x)
imag = conv1d(x, self.wsin, stride=self.stride)
real = conv1d(x, self.wcos, stride=self.stride)
return (real, -imag)
def inverse(self,x_real,x_imag):
"""
        Convert a batch of spectra back to waveforms.
Parameters
----------
x_real : torch tensor
Real part of the signal.
x_imag : torch tensor
Imaginary part of the signal.
"""
x_real = broadcast_dim(x_real)
x_imag = broadcast_dim(x_imag)
x_real.transpose_(1,2) # Prepare the right shape to do inverse
x_imag.transpose_(1,2) # Prepare the right shape to do inverse
# if self.center:
# if self.pad_mode == 'constant':
# padding = nn.ConstantPad1d(self.n_fft//2, 0)
# elif self.pad_mode == 'reflect':
# padding = nn.ReflectionPad1d(self.n_fft//2)
# x_real = padding(x_real)
# x_imag = padding(x_imag)
        # Watch out for the positive and negative signs:
        # the inverse DFT kernel is e^{+2*pi*j*k*n/N}, so
        # ifft(X_real) = (a1, a2)
        # ifft(X_imag)*1j = (b1, b2)*1j
        #                = (-b2, b1)
a1 = conv1d(x_real, self.wcos, stride=self.stride)
a2 = conv1d(x_real, self.wsin, stride=self.stride)
b1 = conv1d(x_imag, self.wcos, stride=self.stride)
b2 = conv1d(x_imag, self.wsin, stride=self.stride)
imag = a2+b1
real = a1-b2
return (real/self.n_fft, imag/self.n_fft)
class iSTFT(torch.nn.Module):
"""This class is to convert spectrograms back to waveforms. It only works for the complex value spectrograms.
If you have the magnitude spectrograms, please use :func:`~nnAudio.Spectrogram.Griffin_Lim`.
The parameters (e.g. n_fft, window) need to be the same as the STFT in order to obtain the correct inverse.
If trainability is not required, it is recommended to use the ``inverse`` method under the ``STFT`` class
to save GPU/RAM memory.
When ``trainable=True`` and ``freq_scale!='no'``, there is no guarantee that the inverse is perfect, please
use with extra care.
Parameters
----------
n_fft : int
The window size. Default value is 2048.
freq_bins : int
Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
Please make sure the value is the same as the forward STFT.
hop_length : int
The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
Please make sure the value is the same as the forward STFT.
window : str
The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
Please make sure the value is the same as the forward STFT.
freq_scale : 'linear', 'log', or 'no'
Determine the spacing between each frequency bin. When `linear` or `log` is used,
the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bin will
start at 0Hz and end at Nyquist frequency with linear spacing.
Please make sure the value is the same as the forward STFT.
center : bool
        Putting the iSTFT kernel at the center of the time-step or not. If ``False``, the time
        index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of
        the iSTFT kernel. Default value is ``True``.
Please make sure the value is the same as the forward STFT.
fmin : int
The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument
does nothing. Please make sure the value is the same as the forward STFT.
fmax : int
The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument
does nothing. Please make sure the value is the same as the forward STFT.
sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
trainable_kernels : bool
        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for STFT
        kernels will also be calculated and the STFT kernels will be updated during model training.
Default value is ``False``.
trainable_window : bool
Determine if the window function is trainable or not.
Default value is ``False``.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints.
Returns
-------
spectrogram : torch.tensor
It returns a batch of waveforms.
Examples
--------
>>> spec_layer = Spectrogram.iSTFT()
>>> specs = spec_layer(x)
"""
def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',
freq_scale='no', center=True, fmin=50, fmax=6000, sr=22050, trainable_kernels=False,
trainable_window=False, verbose=True, refresh_win=True):
super().__init__()
# Trying to make the default setting same as librosa
if win_length==None: win_length = n_fft
if hop_length==None: hop_length = int(win_length // 4)
self.n_fft = n_fft
self.win_length = win_length
self.stride = hop_length
self.center = center
self.pad_amount = self.n_fft // 2
self.refresh_win = refresh_win
start = time()
# Create the window function and prepare the shape for batch-wise-time-wise multiplication
# Create filter windows for inverse
kernel_sin, kernel_cos, _, _, window_mask = create_fourier_kernels(n_fft,
win_length=win_length,
freq_bins=n_fft,
window=window,
freq_scale=freq_scale,
fmin=fmin,
fmax=fmax,
sr=sr,
verbose=False)
window_mask = get_window(window,int(win_length), fftbins=True)
# For inverse, the Fourier kernels do not need to be windowed
window_mask = torch.tensor(window_mask).unsqueeze(0).unsqueeze(-1)
# kernel_sin and kernel_cos have the shape (freq_bins, 1, n_fft, 1) to support 2D Conv
kernel_sin = torch.tensor(kernel_sin, dtype=torch.float).unsqueeze(-1)
kernel_cos = torch.tensor(kernel_cos, dtype=torch.float).unsqueeze(-1)
# Decide if the Fourier kernels are trainable
if trainable_kernels:
# Making all these variables trainable
kernel_sin = torch.nn.Parameter(kernel_sin, requires_grad=trainable_kernels)
kernel_cos = torch.nn.Parameter(kernel_cos, requires_grad=trainable_kernels)
self.register_parameter('kernel_sin', kernel_sin)
self.register_parameter('kernel_cos', kernel_cos)
else:
self.register_buffer('kernel_sin', kernel_sin)
self.register_buffer('kernel_cos', kernel_cos)
# Decide if the window function is trainable
if trainable_window:
window_mask = torch.nn.Parameter(window_mask, requires_grad=trainable_window)
self.register_parameter('window_mask', window_mask)
else:
self.register_buffer('window_mask', window_mask)
if verbose==True:
print("iSTFT kernels created, time used = {:.4f} seconds".format(time()-start))
else:
pass
def forward(self, X, onesided=False, length=None, refresh_win=None):
"""
If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,
else use ``onesided=False``
To make sure the inverse STFT has the same output length of the original waveform, please
set `length` as your intended waveform length. By default, ``length=None``,
which will remove ``n_fft//2`` samples from the start and the end of the output.
If your input spectrograms X are of the same length, please use ``refresh_win=None`` to increase
computational speed.
"""
if refresh_win==None:
refresh_win=self.refresh_win
        assert X.dim()==4 , "Inverse iSTFT only works for complex numbers; " \
                            "make sure your tensor is in the shape of (batch, freq_bins, timesteps, 2)"
# If the input spectrogram contains only half of the n_fft
# Use extend_fbins function to get back another half
if onesided:
X = extend_fbins(X) # extend freq
X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]
# broadcast dimensions to support 2D convolution
X_real_bc = X_real.unsqueeze(1)
X_imag_bc = X_imag.unsqueeze(1)
a1 = conv2d(X_real_bc, self.kernel_cos, stride=(1,1))
b2 = conv2d(X_imag_bc, self.kernel_sin, stride=(1,1))
# compute real and imag part. signal lies in the real part
real = a1 - b2
real = real.squeeze(-2)*self.window_mask
# Normalize the amplitude with n_fft
real /= (self.n_fft)
# Overlap and Add algorithm to connect all the frames
real = overlap_add(real, self.stride)
# Prepare the window sumsqure for division
# Only need to create this window once to save time
# Unless the input spectrograms have different time steps
if hasattr(self, 'w_sum')==False or refresh_win==True:
self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()
self.nonzero_indices = (self.w_sum>1e-10)
else:
pass
real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])
# Remove padding
if length is None:
if self.center:
real = real[:, self.pad_amount:-self.pad_amount]
else:
if self.center:
real = real[:, self.pad_amount:self.pad_amount + length]
else:
real = real[:, :length]
return real
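# A round-trip sketch (illustrative, not part of the original file), mirroring the
# usage in the test suite below: the iSTFT parameters must match the forward STFT.
#
# >>> import torch
# >>> x = torch.randn(1, 16000)                              # fake waveform
# >>> stft = STFT(n_fft=2048, hop_length=512)
# >>> istft = iSTFT(n_fft=2048, hop_length=512)
# >>> X = stft(x, output_format="Complex")                   # (1, n_fft//2+1, T, 2)
# >>> x_recon = istft(X, onesided=True, length=x.shape[1])   # (1, 16000)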
class Griffin_Lim(torch.nn.Module):
"""
Converting Magnitude spectrograms back to waveforms based on the "fast Griffin-Lim"[1].
This Griffin Lim is a direct clone from librosa.griffinlim.
[1] Perraudin, N., Balazs, P., & Søndergaard, P. L. “A fast Griffin-Lim algorithm,”
IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4), Oct. 2013.
Parameters
----------
n_fft : int
The window size. Default value is 2048.
    n_iter : int
        The number of iterations for Griffin-Lim. The default value is ``32``.
hop_length : int
The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
Please make sure the value is the same as the forward STFT.
window : str
The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
Please make sure the value is the same as the forward STFT.
center : bool
        Putting the iSTFT kernel at the center of the time-step or not. If ``False``, the time
        index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of
        the iSTFT kernel. Default value is ``True``.
Please make sure the value is the same as the forward STFT.
momentum : float
The momentum for the update rule. The default value is ``0.99``.
device : str
Choose which device to initialize this layer. Default value is 'cpu'
"""
def __init__(self,
n_fft,
n_iter=32,
hop_length=None,
win_length=None,
window='hann',
center=True,
pad_mode='reflect',
momentum=0.99,
device='cpu'):
super().__init__()
self.n_fft = n_fft
self.win_length = win_length
self.n_iter = n_iter
self.center = center
self.pad_mode = pad_mode
self.momentum = momentum
self.device = device
if win_length==None:
self.win_length=n_fft
else:
self.win_length=win_length
if hop_length==None:
self.hop_length = n_fft//4
else:
self.hop_length = hop_length
# Creating window function for stft and istft later
self.w = torch.tensor(get_window(window,
int(self.win_length),
fftbins=True),
device=device).float()
def forward(self, S):
"""
Convert a batch of magnitude spectrograms to waveforms.
Parameters
----------
S : torch tensor
Spectrogram of the shape ``(batch, n_fft//2+1, timesteps)``
"""
assert S.dim()==3 , "Please make sure your input is in the shape of (batch, freq_bins, timesteps)"
# Initializing Random Phase
rand_phase = torch.randn(*S.shape, device=self.device)
angles = torch.empty((*S.shape,2), device=self.device)
angles[:, :,:,0] = torch.cos(2 * np.pi * rand_phase)
angles[:,:,:,1] = torch.sin(2 * np.pi * rand_phase)
# Initializing the rebuilt magnitude spectrogram
rebuilt = torch.zeros(*angles.shape, device=self.device)
for _ in range(self.n_iter):
tprev = rebuilt # Saving previous rebuilt magnitude spec
# spec2wav conversion
# print(f'win_length={self.win_length}\tw={self.w.shape}')
inverse = torch.istft(S.unsqueeze(-1) * angles,
self.n_fft,
self.hop_length,
win_length=self.win_length,
window=self.w,
center=self.center)
# wav2spec conversion
rebuilt = torch.stft(inverse,
self.n_fft,
self.hop_length,
win_length=self.win_length,
window=self.w,
pad_mode=self.pad_mode)
# Phase update rule
angles[:,:,:] = rebuilt[:,:,:] - (self.momentum / (1 + self.momentum)) * tprev[:,:,:]
# Phase normalization
angles = angles.div(torch.sqrt(angles.pow(2).sum(-1)).unsqueeze(-1) + 1e-16) # normalizing the phase
# Using the final phase to reconstruct the waveforms
inverse = torch.istft(S.unsqueeze(-1) * angles,
self.n_fft,
self.hop_length,
win_length=self.win_length,
window=self.w,
center=self.center)
return inverse
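# Usage sketch (illustrative, not part of the original file): feed a magnitude
# spectrogram whose n_fft/hop_length match the constructor arguments.
#
# >>> import torch
# >>> S = torch.rand(1, 1025, 100)              # (batch, n_fft//2+1, timesteps)
# >>> griffin_lim = Griffin_Lim(n_fft=2048, n_iter=32, hop_length=512)
# >>> waveform = griffin_lim(S)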
class Combined_Frequency_Periodicity(nn.Module):
"""
Vectorized version of the code in https://github.com/leo-so/VocalMelodyExtPatchCNN/blob/master/MelodyExt.py.
This feature is described in 'Combining Spectral and Temporal Representations for Multipitch Estimation of Polyphonic Music'
https://ieeexplore.ieee.org/document/7118691
Under development, please report any bugs you found
"""
def __init__(self,fr=2, fs=16000, hop_length=320,
window_size=2049, fc=80, tc=1/1000,
g=[0.24, 0.6, 1], NumPerOct=48):
super().__init__()
self.window_size = window_size
self.hop_length = hop_length
# variables for STFT part
self.N = int(fs/float(fr)) # Will be used to calculate padding
        self.f = fs*np.linspace(0, 0.5, self.N//2, endpoint=True) # it won't be used internally, but is kept as an attribute; np.linspace needs an int count
        self.pad_value = self.N - window_size
# Create window function, always blackmanharris?
h = scipy.signal.blackmanharris(window_size).astype(np.float32) # window function for STFT
self.register_buffer('h',torch.tensor(h))
# variables for CFP
self.NumofLayer = np.size(g)
self.g = g
self.tc_idx = round(fs*tc) # index to filter out top tc_idx and bottom tc_idx bins
self.fc_idx = round(fc/fr) # index to filter out top fc_idx and bottom fc_idx bins
self.HighFreqIdx = int(round((1/tc)/fr)+1)
self.HighQuefIdx = int(round(fs/fc)+1)
# attributes to be returned
self.f = self.f[:self.HighFreqIdx]
self.q = np.arange(self.HighQuefIdx)/float(fs)
# filters for the final step
freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix.astype(np.float32)))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix.astype(np.float32)))
def _CFP(self, spec):
spec = torch.relu(spec).pow(self.g[0])
if self.NumofLayer >= 2:
for gc in range(1, self.NumofLayer):
if np.remainder(gc, 2) == 1:
                    # torch.rfft was removed in PyTorch 1.8; torch.fft.fft(...).real is the equivalent
                    ceps = torch.fft.fft(spec, dim=-1).real/np.sqrt(self.N)
ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
else:
                    spec = torch.fft.fft(ceps, dim=-1).real/np.sqrt(self.N)  # see the torch.rfft note above
spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)
return spec, ceps
def forward(self, x):
tfr0 = torch.stft(x, self.N, hop_length=self.hop_length, win_length=self.window_size,
window=self.h, onesided=False, pad_mode='constant')
        tfr0 = torch.sqrt(tfr0.pow(2).sum(-1))/torch.norm(self.h) # calculate magnitude
tfr0 = tfr0.transpose(1,2)[:,1:-1] #transpose F and T axis and discard first and last frames
# The transpose is necessary for rfft later
# (batch, timesteps, n_fft)
tfr, ceps = self._CFP(tfr0)
# return tfr0
# removing duplicate bins
tfr0 = tfr0[:,:,:int(round(self.N/2))]
tfr = tfr[:,:,:int(round(self.N/2))]
ceps = ceps[:,:,:int(round(self.N/2))]
# Crop up to the highest frequency
tfr0 = tfr0[:,:,:self.HighFreqIdx]
tfr = tfr[:,:,:self.HighFreqIdx]
ceps = ceps[:,:,:self.HighQuefIdx]
tfrL0 = torch.matmul(self.freq2logfreq_matrix, tfr0.transpose(1,2))
tfrLF = torch.matmul(self.freq2logfreq_matrix, tfr.transpose(1,2))
tfrLQ = torch.matmul(self.quef2logfreq_matrix, ceps.transpose(1,2))
Z = tfrLF * tfrLQ
# Only need to calculate this once
self.t = np.arange(self.hop_length,
np.ceil(len(x)/float(self.hop_length))*self.hop_length,
self.hop_length) # it won't be used but will be returned
return Z, tfrL0, tfrLF, tfrLQ
def nonlinear_func(self, X, g, cutoff):
cutoff = int(cutoff)
if g!=0:
X = torch.relu(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
X = X.pow(g)
else: # when g=0, it converges to log
X = torch.log(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
return X
def create_logfreq_matrix(self, f, q, fr, fc, tc, NumPerOct, fs):
StartFreq = fc
StopFreq = 1/tc
Nest = int(np.ceil(np.log2(StopFreq/StartFreq))*NumPerOct)
central_freq = [] # A list holding the frequencies in log scale
for i in range(0, Nest):
CenFreq = StartFreq*pow(2, float(i)/NumPerOct)
if CenFreq < StopFreq:
central_freq.append(CenFreq)
else:
break
Nest = len(central_freq)
        freq_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)  # np.float was removed from NumPy
# Calculating the freq_band_transformation
for i in range(1, Nest-1):
l = int(round(central_freq[i-1]/fr))
r = int(round(central_freq[i+1]/fr)+1)
#rounding1
if l >= r-1:
freq_band_transformation[i, l] = 1
else:
for j in range(l, r):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
freq_band_transformation[i, j] = (f[j] - central_freq[i-1]) / (central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
freq_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
# Calculating the quef_band_transformation
        f = 1/q # q[0]==0 gives inf here; index 0 is never touched by the loop below, so it is harmless
        quef_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)
for i in range(1, Nest-1):
for j in range(int(round(fs/central_freq[i+1])), int(round(fs/central_freq[i-1])+1)):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
quef_band_transformation[i, j] = (f[j] - central_freq[i-1])/(central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
quef_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
return freq_band_transformation, quef_band_transformation
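# Usage sketch (illustrative; the class above is flagged as under development):
#
# >>> import torch
# >>> x = torch.randn(1, 16000)                 # (batch, len_audio) at fs=16000
# >>> cfp_layer = Combined_Frequency_Periodicity(fs=16000, hop_length=320)
# >>> Z, tfrL0, tfrLF, tfrLQ = cfp_layer(x)     # log-frequency spectrogram features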
class CFP(nn.Module):
"""
This is the modified version so that the number of timesteps fits with other classes
Under development, please report any bugs you found
"""
def __init__(self,fr=2, fs=16000, hop_length=320,
window_size=2049, fc=80, tc=1/1000,
g=[0.24, 0.6, 1], NumPerOct=48):
super().__init__()
self.window_size = window_size
self.hop_length = hop_length
# variables for STFT part
self.N = int(fs/float(fr)) # Will be used to calculate padding
        self.f = fs*np.linspace(0, 0.5, self.N//2, endpoint=True) # it won't be used internally, but is kept as an attribute; np.linspace needs an int count
        self.pad_value = self.N - window_size
# Create window function, always blackmanharris?
h = scipy.signal.blackmanharris(window_size).astype(np.float32) # window function for STFT
self.register_buffer('h',torch.tensor(h))
# variables for CFP
self.NumofLayer = np.size(g)
self.g = g
self.tc_idx = round(fs*tc) # index to filter out top tc_idx and bottom tc_idx bins
self.fc_idx = round(fc/fr) # index to filter out top fc_idx and bottom fc_idx bins
self.HighFreqIdx = int(round((1/tc)/fr)+1)
self.HighQuefIdx = int(round(fs/fc)+1)
# attributes to be returned
self.f = self.f[:self.HighFreqIdx]
self.q = np.arange(self.HighQuefIdx)/float(fs)
# filters for the final step
freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix.astype(np.float32)))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix.astype(np.float32)))
def _CFP(self, spec):
spec = torch.relu(spec).pow(self.g[0])
if self.NumofLayer >= 2:
for gc in range(1, self.NumofLayer):
if np.remainder(gc, 2) == 1:
                    # torch.rfft was removed in PyTorch 1.8; torch.fft.fft(...).real is the equivalent
                    ceps = torch.fft.fft(spec, dim=-1).real/np.sqrt(self.N)
ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
else:
                    spec = torch.fft.fft(ceps, dim=-1).real/np.sqrt(self.N)  # see the torch.rfft note above
spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)
return spec, ceps
def forward(self, x):
tfr0 = torch.stft(x, self.N, hop_length=self.hop_length, win_length=self.window_size,
window=self.h, onesided=False, pad_mode='constant')
        tfr0 = torch.sqrt(tfr0.pow(2).sum(-1))/torch.norm(self.h) # calculate magnitude
        tfr0 = tfr0.transpose(1,2) # transpose F and T axes (unlike the class above, all frames are kept)
# The transpose is necessary for rfft later
# (batch, timesteps, n_fft)
tfr, ceps = self._CFP(tfr0)
# return tfr0
# removing duplicate bins
tfr0 = tfr0[:,:,:int(round(self.N/2))]
tfr = tfr[:,:,:int(round(self.N/2))]
ceps = ceps[:,:,:int(round(self.N/2))]
# Crop up to the highest frequency
tfr0 = tfr0[:,:,:self.HighFreqIdx]
tfr = tfr[:,:,:self.HighFreqIdx]
ceps = ceps[:,:,:self.HighQuefIdx]
tfrL0 = torch.matmul(self.freq2logfreq_matrix, tfr0.transpose(1,2))
tfrLF = torch.matmul(self.freq2logfreq_matrix, tfr.transpose(1,2))
tfrLQ = torch.matmul(self.quef2logfreq_matrix, ceps.transpose(1,2))
Z = tfrLF * tfrLQ
# Only need to calculate this once
self.t = np.arange(self.hop_length,
np.ceil(len(x)/float(self.hop_length))*self.hop_length,
self.hop_length) # it won't be used but will be returned
return Z#, tfrL0, tfrLF, tfrLQ
def nonlinear_func(self, X, g, cutoff):
cutoff = int(cutoff)
if g!=0:
X = torch.relu(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
X = X.pow(g)
else: # when g=0, it converges to log
X = torch.log(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
return X
def create_logfreq_matrix(self, f, q, fr, fc, tc, NumPerOct, fs):
StartFreq = fc
StopFreq = 1/tc
Nest = int(np.ceil(np.log2(StopFreq/StartFreq))*NumPerOct)
central_freq = [] # A list holding the frequencies in log scale
for i in range(0, Nest):
CenFreq = StartFreq*pow(2, float(i)/NumPerOct)
if CenFreq < StopFreq:
central_freq.append(CenFreq)
else:
break
Nest = len(central_freq)
        freq_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)  # np.float was removed from NumPy
# Calculating the freq_band_transformation
for i in range(1, Nest-1):
l = int(round(central_freq[i-1]/fr))
r = int(round(central_freq[i+1]/fr)+1)
#rounding1
if l >= r-1:
freq_band_transformation[i, l] = 1
else:
for j in range(l, r):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
freq_band_transformation[i, j] = (f[j] - central_freq[i-1]) / (central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
freq_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
# Calculating the quef_band_transformation
        f = 1/q # q[0]==0 gives inf here; index 0 is never touched by the loop below, so it is harmless
        quef_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)
for i in range(1, Nest-1):
for j in range(int(round(fs/central_freq[i+1])), int(round(fs/central_freq[i-1])+1)):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
quef_band_transformation[i, j] = (f[j] - central_freq[i-1])/(central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
quef_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
return freq_band_transformation, quef_band_transformation
__version__ = "0.2.2"
"""
Module containing functions cloned from librosa
To make sure nnAudio would not become broken when updating librosa
"""
import numpy as np
import warnings

class ParameterError(Exception):
    """Stand-in for librosa.util.exceptions.ParameterError, defined here so this
    module does not require librosa to be installed."""
    pass
### ----------------Functions for generating kernels for Mel Spectrogram------------ ###
# This code is equivalent to `from librosa.filters import mel`
# By doing so, we can run nnAudio without installing librosa
def fft2gammatonemx(sr=20000, n_fft=2048, n_bins=64, width=1.0, fmin=0.0,
fmax=11025, maxlen=1024):
"""
# Ellis' description in MATLAB:
# [wts,cfreqa] = fft2gammatonemx(nfft, sr, nfilts, width, minfreq, maxfreq, maxlen)
# Generate a matrix of weights to combine FFT bins into
# Gammatone bins. nfft defines the source FFT size at
# sampling rate sr. Optional nfilts specifies the number of
# output bands required (default 64), and width is the
# constant width of each band in Bark (default 1).
# minfreq, maxfreq specify range covered in Hz (100, sr/2).
# While wts has nfft columns, the second half are all zero.
# Hence, aud spectrum is
# fft2gammatonemx(nfft,sr)*abs(fft(xincols,nfft));
# maxlen truncates the rows to this many bins.
# cfreqs returns the actual center frequencies of each
# gammatone band in Hz.
#
# 2009/02/22 02:29:25 Dan Ellis dpwe@ee.columbia.edu based on rastamat/audspec.m
# Sat May 27 15:37:50 2017 Maddie Cusimano, mcusi@mit.edu 27 May 2017: convert to python
"""
wts = np.zeros([n_bins, n_fft], dtype=np.float32)
# after Slaney's MakeERBFilters
    EarQ = 9.26449
    minBW = 24.7
    order = 1
nFr = np.array(range(n_bins)) + 1
em = EarQ * minBW
cfreqs = (fmax + em) * np.exp(nFr * (-np.log(fmax + em) + np.log(fmin + em)) / n_bins) - em
cfreqs = cfreqs[::-1]
GTord = 4
ucircArray = np.array(range(int(n_fft / 2 + 1)))
    ucirc = np.exp(1j * 2 * np.pi * ucircArray / n_fft)
# justpoles = 0 :taking out the 'if' corresponding to this.
    ERB = width * np.power(np.power(cfreqs / EarQ, order) + np.power(minBW, order), 1 / order)
    B = 1.019 * 2 * np.pi * ERB
r = np.exp(-B / sr)
theta = 2 * np.pi * cfreqs / sr
pole = r * np.exp(1j * theta)
T = 1 / sr
    ebt = np.exp(B * T)
    cpt = 2 * cfreqs * np.pi * T
    ccpt = 2 * T * np.cos(cpt)
    scpt = 2 * T * np.sin(cpt)
    A11 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2)
    A12 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2)
    A13 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2)
    A14 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2)
    zros = -np.array([A11, A12, A13, A14]) / T
wIdx = range(int(n_fft / 2 + 1))
gain = np.abs((-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 - 2 ** (3 / 2)) * np.sin(
2 * cfreqs * np.pi * T))) * (-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (np.cos(2 * cfreqs * np.pi * T) + np.sqrt(
3 - 2 ** (3 / 2)) * np.sin(2 * cfreqs * np.pi * T))) * (
-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
2 * cfreqs * np.pi * T))) * (
-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
np.cos(2 * cfreqs * np.pi * T) + np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
2 * cfreqs * np.pi * T))) / (
-2 / np.exp(2 * B * T) - 2 * np.exp(4 * 1j * cfreqs * np.pi * T) + 2 * (
                          1 + np.exp(4 * 1j * cfreqs * np.pi * T)) / np.exp(B * T)) ** 4)
    # in MATLAB this constant was hard-coded as 64; here it is n_bins:
wts[:, wIdx] = ((T ** 4) / np.reshape(gain, (n_bins, 1))) * np.abs(
ucirc - np.reshape(zros[0], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[1], (n_bins, 1))) * np.abs(
ucirc - np.reshape(zros[2], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[3], (n_bins, 1))) * (np.abs(
np.power(np.multiply(np.reshape(pole, (n_bins, 1)) - ucirc, np.conj(np.reshape(pole, (n_bins, 1))) - ucirc),
        -GTord)))
    wts = wts[:, range(maxlen)]
return wts, cfreqs
def gammatone(sr, n_fft, n_bins=64, fmin=20.0, fmax=None, htk=False,
norm=1, dtype=np.float32):
"""Create a Filterbank matrix to combine FFT bins into Gammatone bins
Parameters
----------
sr : number > 0 [scalar]
sampling rate of the incoming signal
n_fft : int > 0 [scalar]
number of FFT components
n_bins : int > 0 [scalar]
        number of Gammatone bands to generate
fmin : float >= 0 [scalar]
lowest frequency (in Hz)
fmax : float >= 0 [scalar]
highest frequency (in Hz).
If `None`, use `fmax = sr / 2.0`
htk : bool [scalar]
use HTK formula instead of Slaney
norm : {None, 1, np.inf} [scalar]
if 1, divide the triangular mel weights by the width of the mel band
(area normalization). Otherwise, leave all the triangles aiming for
a peak value of 1.0
dtype : np.dtype
The data type of the output basis.
By default, uses 32-bit (single-precision) floating point.
Returns
-------
G : np.ndarray [shape=(n_bins, 1 + n_fft/2)]
Gammatone transform matrix
"""
if fmax is None:
fmax = float(sr) / 2
n_bins = int(n_bins)
weights,_ = fft2gammatonemx(sr=sr, n_fft=n_fft, n_bins=n_bins, fmin=fmin, fmax=fmax, maxlen=int(n_fft//2+1))
return (1/n_fft)*weights
def mel_to_hz(mels, htk=False):
"""Convert mel bin numbers to frequencies
Examples
--------
>>> librosa.mel_to_hz(3)
200.
>>> librosa.mel_to_hz([1,2,3,4,5])
array([ 66.667, 133.333, 200. , 266.667, 333.333])
Parameters
----------
mels : np.ndarray [shape=(n,)], float
mel bins to convert
htk : bool
use HTK formula instead of Slaney
Returns
-------
frequencies : np.ndarray [shape=(n,)]
input mels in Hz
See Also
--------
hz_to_mel
"""
mels = np.asanyarray(mels)
if htk:
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if mels.ndim:
# If we have vector data, vectorize
log_t = (mels >= min_log_mel)
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
elif mels >= min_log_mel:
# If we have scalar data, check directly
freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
return freqs
def hz_to_mel(frequencies, htk=False):
"""Convert Hz to Mels
Examples
--------
>>> librosa.hz_to_mel(60)
0.9
>>> librosa.hz_to_mel([110, 220, 440])
array([ 1.65, 3.3 , 6.6 ])
Parameters
----------
frequencies : number or np.ndarray [shape=(n,)] , float
scalar or array of frequencies
htk : bool
use HTK formula instead of Slaney
Returns
-------
mels : number or np.ndarray [shape=(n,)]
input frequencies in Mels
See Also
--------
mel_to_hz
"""
frequencies = np.asanyarray(frequencies)
if htk:
return 2595.0 * np.log10(1.0 + frequencies / 700.0)
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (frequencies - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if frequencies.ndim:
# If we have array data, vectorize
log_t = (frequencies >= min_log_hz)
mels[log_t] = min_log_mel + np.log(frequencies[log_t]/min_log_hz) / logstep
elif frequencies >= min_log_hz:
    # If we have scalar data, check directly
mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
return mels
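# Quick sanity check (illustrative, not part of the original file): hz_to_mel and
# mel_to_hz are inverses on both the linear (<1 kHz) and log (>=1 kHz) regions.
#
# >>> round(float(mel_to_hz(hz_to_mel(440.0))), 3)
# 440.0
# >>> round(float(mel_to_hz(hz_to_mel(4400.0))), 3)
# 4400.0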
def fft_frequencies(sr=22050, n_fft=2048):
'''Alternative implementation of `np.fft.fftfreq`
Parameters
----------
sr : number > 0 [scalar]
Audio sampling rate
n_fft : int > 0 [scalar]
FFT window size
Returns
-------
freqs : np.ndarray [shape=(1 + n_fft/2,)]
Frequencies `(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)`
Examples
--------
>>> librosa.fft_frequencies(sr=22050, n_fft=16)
array([ 0. , 1378.125, 2756.25 , 4134.375,
5512.5 , 6890.625, 8268.75 , 9646.875, 11025. ])
'''
return np.linspace(0,
float(sr) / 2,
int(1 + n_fft//2),
endpoint=True)
def mel_frequencies(n_mels=128, fmin=0.0, fmax=11025.0, htk=False):
"""
This function is cloned from librosa 0.7.
Please refer to the original
`documentation <https://librosa.org/doc/latest/generated/librosa.mel_frequencies.html?highlight=mel_frequencies#librosa.mel_frequencies>`__
for more info.
Parameters
----------
n_mels : int > 0 [scalar]
Number of mel bins.
fmin : float >= 0 [scalar]
Minimum frequency (Hz).
fmax : float >= 0 [scalar]
Maximum frequency (Hz).
htk : bool
If True, use HTK formula to convert Hz to mel.
Otherwise (False), use Slaney's Auditory Toolbox.
Returns
-------
bin_frequencies : ndarray [shape=(n_mels,)]
Vector of n_mels frequencies in Hz which are uniformly spaced on the Mel
axis.
Examples
--------
>>> librosa.mel_frequencies(n_mels=40)
array([ 0. , 85.317, 170.635, 255.952,
341.269, 426.586, 511.904, 597.221,
682.538, 767.855, 853.173, 938.49 ,
1024.856, 1119.114, 1222.042, 1334.436,
1457.167, 1591.187, 1737.532, 1897.337,
2071.84 , 2262.393, 2470.47 , 2697.686,
2945.799, 3216.731, 3512.582, 3835.643,
4188.417, 4573.636, 4994.285, 5453.621,
5955.205, 6502.92 , 7101.009, 7754.107,
8467.272, 9246.028, 10096.408, 11025. ])
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(fmin, htk=htk)
max_mel = hz_to_mel(fmax, htk=htk)
mels = np.linspace(min_mel, max_mel, n_mels)
return mel_to_hz(mels, htk=htk)
def mel(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False,
norm=1, dtype=np.float32):
"""
This function is cloned from librosa 0.7.
Please refer to the original
`documentation <https://librosa.org/doc/latest/generated/librosa.filters.mel.html>`__
for more info.
Create a Filterbank matrix to combine FFT bins into Mel-frequency bins
Parameters
----------
sr : number > 0 [scalar]
sampling rate of the incoming signal
n_fft : int > 0 [scalar]
number of FFT components
n_mels : int > 0 [scalar]
number of Mel bands to generate
fmin : float >= 0 [scalar]
lowest frequency (in Hz)
fmax : float >= 0 [scalar]
highest frequency (in Hz).
If `None`, use `fmax = sr / 2.0`
htk : bool [scalar]
use HTK formula instead of Slaney
norm : {None, 1, np.inf} [scalar]
if 1, divide the triangular mel weights by the width of the mel band
(area normalization). Otherwise, leave all the triangles aiming for
a peak value of 1.0
dtype : np.dtype
The data type of the output basis.
By default, uses 32-bit (single-precision) floating point.
Returns
-------
M : np.ndarray [shape=(n_mels, 1 + n_fft/2)]
Mel transform matrix
Notes
-----
This function caches at level 10.
Examples
--------
>>> melfb = librosa.filters.mel(22050, 2048)
>>> melfb
array([[ 0. , 0.016, ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ],
...,
[ 0. , 0. , ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ]])
Clip the maximum frequency to 8KHz
>>> librosa.filters.mel(22050, 2048, fmax=8000)
array([[ 0. , 0.02, ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ],
...,
[ 0. , 0. , ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ]])
>>> import matplotlib.pyplot as plt
>>> plt.figure()
>>> librosa.display.specshow(melfb, x_axis='linear')
>>> plt.ylabel('Mel filter')
>>> plt.title('Mel filter bank')
>>> plt.colorbar()
>>> plt.tight_layout()
>>> plt.show()
"""
if fmax is None:
fmax = float(sr) / 2
if norm is not None and norm != 1 and norm != np.inf:
raise ParameterError('Unsupported norm: {}'.format(repr(norm)))
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
# 'Center freqs' of mel bands - uniformly spaced between limits
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i+2] / fdiff[i+1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
if norm == 1:
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2:n_mels+2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
# Only check weights if f_mel[0] is positive
if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
# This means we have an empty channel somewhere
warnings.warn('Empty filters detected in mel frequency basis. '
'Some channels will produce empty responses. '
'Try increasing your sampling rate (and fmax) or '
'reducing n_mels.')
return weights
### ------------------End of functions for generating kernels for Mel Spectrogram ----------------###
### ------------------Functions for making STFT same as librosa ---------------------------------###
def pad_center(data, size, axis=-1, **kwargs):
'''Wrapper for np.pad to automatically center an array prior to padding.
This is analogous to `str.center()`
Examples
--------
>>> # Generate a vector
>>> data = np.ones(5)
>>> librosa.util.pad_center(data, 10, mode='constant')
array([ 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.])
>>> # Pad a matrix along its first dimension
>>> data = np.ones((3, 5))
>>> librosa.util.pad_center(data, 7, axis=0)
array([[ 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0.],
[ 1., 1., 1., 1., 1.],
[ 1., 1., 1., 1., 1.],
[ 1., 1., 1., 1., 1.],
[ 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0.]])
>>> # Or its second dimension
>>> librosa.util.pad_center(data, 7, axis=1)
array([[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.]])
Parameters
----------
data : np.ndarray
Vector to be padded and centered
size : int >= len(data) [scalar]
Length to pad `data`
axis : int
Axis along which to pad and center the data
kwargs : additional keyword arguments
arguments passed to `np.pad()`
Returns
-------
data_padded : np.ndarray
`data` centered and padded to length `size` along the
specified axis
Raises
------
ParameterError
If `size < data.shape[axis]`
See Also
--------
numpy.pad
'''
kwargs.setdefault('mode', 'constant')
n = data.shape[axis]
lpad = int((size - n) // 2)
lengths = [(0, 0)] * data.ndim
lengths[axis] = (lpad, int(size - n - lpad))
if lpad < 0:
raise ParameterError(('Target size ({:d}) must be '
'at least input size ({:d})').format(size, n))
return np.pad(data, lengths, **kwargs)
### ------------------End of functions for making STFT same as librosa ---------------------------###
"""
Module containing helper functions such as overlap sum and Fourier kernels generators
"""
import torch
from torch.nn.functional import conv1d, fold
import numpy as np
from time import time
import math
from scipy.signal import get_window
from scipy import signal
from scipy import fft
import warnings
from nnAudio.librosa_functions import *
## --------------------------- Filter Design ---------------------------##
def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
    """Compute the sum-square envelope of the window ``w`` over ``n_frames``
    overlapping frames (the torch counterpart of librosa's window_sumsquare)."""
    w_stacks = w.unsqueeze(-1).repeat((1,n_frames)).unsqueeze(0)
# Window length + stride*(frames-1)
output_len = w_stacks.shape[1] + stride*(w_stacks.shape[2]-1)
return fold(w_stacks**power, (1,output_len), kernel_size=(1,n_fft), stride=stride)
def overlap_add(X, stride):
    """Overlap-add frames ``X`` of shape (batch, n_fft, n_frames) back into a
    waveform of length ``n_fft + stride*(n_frames-1)``."""
    n_fft = X.shape[1]
output_len = n_fft + stride*(X.shape[2]-1)
return fold(X, (1,output_len), kernel_size=(1,n_fft), stride=stride).flatten(1)
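# Shape sketch (illustrative, not part of the original file): frames of shape
# (batch, n_fft, n_frames) fold back into a waveform of length n_fft + stride*(n_frames-1).
#
# >>> import torch
# >>> frames = torch.ones(1, 4, 3)              # batch=1, n_fft=4, n_frames=3
# >>> overlap_add(frames, stride=2).shape
# torch.Size([1, 8])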
def uniform_distribution(r1,r2, *size, device):
return (r1 - r2) * torch.rand(*size, device=device) + r2
def extend_fbins(X):
"""Extending the number of frequency bins from `n_fft//2+1` back to `n_fft` by
reversing all bins except DC and Nyquist and append it on top of existing spectrogram"""
X_upper = torch.flip(X[:,1:-1],(0,1))
X_upper[:,:,:,1] = -X_upper[:,:,:,1] # For the imaganinry part, it is an odd function
return torch.cat((X[:, :, :], X_upper), 1)
def downsampling_by_n(x, filterKernel, n):
"""A helper function that downsamples the audio by a arbitary factor n.
It is used in CQT2010 and CQT2010v2.
Parameters
----------
x : torch.Tensor
The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
    filterKernel : torch.Tensor
        Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
n : int
The downsampling factor
Returns
-------
torch.Tensor
The downsampled waveform
Examples
--------
    >>> x_down = downsampling_by_n(x, filterKernel, 2)
"""
x = conv1d(x,filterKernel,stride=n, padding=(filterKernel.shape[-1]-1)//2)
return x
def downsampling_by_2(x, filterKernel):
"""A helper function that downsamples the audio by half. It is used in CQT2010 and CQT2010v2
Parameters
----------
x : torch.Tensor
The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
    filterKernel : torch.Tensor
        Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
Returns
-------
torch.Tensor
The downsampled waveform
Examples
--------
>>> x_down = downsampling_by_2(x, filterKernel)
"""
x = conv1d(x,filterKernel,stride=2, padding=(filterKernel.shape[-1]-1)//2)
return x
## Basic tools for computation ##
def nextpow2(A):
"""A helper function to calculate the next nearest number to the power of 2.
Parameters
----------
A : float
A float number that is going to be rounded up to the nearest power of 2
Returns
-------
int
        The exponent ``n`` such that ``2**n`` is the smallest power of 2 that is >= ``A``
Examples
--------
>>> nextpow2(6)
3
"""
return int(np.ceil(np.log2(A)))
## Basic tools for computation ##
def prepow2(A):
"""A helper function to calculate the next nearest number to the power of 2.
Parameters
----------
A : float
A float number that is going to be rounded up to the nearest power of 2
Returns
-------
int
The nearest power of 2 to the input number ``A``
Examples
--------
>>> nextpow2(6)
3
"""
return int(np.floor(np.log2(A)))
def complex_mul(cqt_filter, stft):
"""Since PyTorch does not support complex numbers and its operation.
We need to write our own complex multiplication function. This one is specially
designed for CQT usage.
Parameters
----------
    cqt_filter : tuple of torch.Tensor
        The CQT kernels in the format of ``(real_torch_tensor, imag_torch_tensor)``
    stft : tuple of torch.Tensor
        The STFT result in the format of ``(real_torch_tensor, imag_torch_tensor)``
Returns
-------
tuple of torch.Tensor
The output is in the format of ``(real_torch_tensor, imag_torch_tensor)``
"""
cqt_filter_real = cqt_filter[0]
cqt_filter_imag = cqt_filter[1]
fourier_real = stft[0]
fourier_imag = stft[1]
CQT_real = torch.matmul(cqt_filter_real, fourier_real) - torch.matmul(cqt_filter_imag, fourier_imag)
CQT_imag = torch.matmul(cqt_filter_real, fourier_imag) + torch.matmul(cqt_filter_imag, fourier_real)
return CQT_real, CQT_imag
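# Shape sketch (illustrative, not part of the original file): with kernels of shape
# (n_bins, freq) and an STFT of shape (batch, freq, time), the batched matmuls
# implement (a+bi)(c+di) = (ac-bd) + (ad+bc)i and return (batch, n_bins, time).
#
# >>> import torch
# >>> filt = (torch.randn(84, 1025), torch.randn(84, 1025))
# >>> four = (torch.randn(2, 1025, 100), torch.randn(2, 1025, 100))
# >>> real, imag = complex_mul(filt, four)
# >>> real.shape
# torch.Size([2, 84, 100])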
def broadcast_dim(x):
"""
    Auto broadcast input so that it can fit into a Conv1d
"""
if x.dim() == 2:
x = x[:, None, :]
elif x.dim() == 1:
# If nn.DataParallel is used, this broadcast doesn't work
x = x[None, None, :]
elif x.dim() == 3:
pass
else:
raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
return x
def broadcast_dim_conv2d(x):
"""
    Auto broadcast input so that it can fit into a Conv2d
"""
if x.dim() == 3:
x = x[:, None, :,:]
else:
raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
return x
## Kernel generation functions ##
def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,
freq_scale='linear', window='hann', verbose=True):
""" This function creates the Fourier Kernel for STFT, Melspectrogram and CQT.
Most of the parameters follow librosa conventions. Part of the code comes from
pytorch_musicnet. https://github.com/jthickstun/pytorch_musicnet
Parameters
----------
n_fft : int
The window size
freq_bins : int
Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
fmin : int
The starting frequency for the lowest frequency bin.
If freq_scale is ``no``, this argument does nothing.
fmax : int
The ending frequency for the highest frequency bin.
If freq_scale is ``no``, this argument does nothing.
sr : int
The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
freq_scale: 'linear', 'log', or 'no'
Determine the spacing between each frequency bin.
When 'linear' or 'log' is used, the bin spacing can be controlled by ``fmin`` and ``fmax``.
If 'no' is used, the bin will start at 0Hz and end at Nyquist frequency with linear spacing.
Returns
-------
wsin : numpy.array
Imaginary Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
wcos : numpy.array
Real Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
bins2freq : list
Mapping each frequency bin to frequency in Hz.
binslist : list
The normalized frequency ``k`` in digital domain.
This ``k`` is in the Discrete Fourier Transform equation $$
"""
if freq_bins==None: freq_bins = n_fft//2+1
if win_length==None: win_length = n_fft
s = np.arange(0, n_fft, 1.)
wsin = np.empty((freq_bins,1,n_fft))
wcos = np.empty((freq_bins,1,n_fft))
start_freq = fmin
end_freq = fmax
bins2freq = []
binslist = []
# num_cycles = start_freq*d/44000.
# scaling_ind = np.log(end_freq/start_freq)/k
# Choosing window shape
window_mask = get_window(window,int(win_length), fftbins=True)
window_mask = pad_center(window_mask, n_fft)
if freq_scale == 'linear':
if verbose==True:
print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to"
f"get a valid freq range")
start_bin = start_freq*n_fft/sr
scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins
for k in range(freq_bins): # Only half of the bins contain useful info
# print("linear freq = {}".format((k*scaling_ind+start_bin)*sr/n_fft))
bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)
binslist.append((k*scaling_ind+start_bin))
wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
elif freq_scale == 'log':
if verbose==True:
print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to"
f"get a valid freq range")
start_bin = start_freq*n_fft/sr
scaling_ind = np.log(end_freq/start_freq)/freq_bins
for k in range(freq_bins): # Only half of the bins contain useful info
# print("log freq = {}".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))
bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)
binslist.append((np.exp(k*scaling_ind)*start_bin))
wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
elif freq_scale == 'no':
for k in range(freq_bins): # Only half of the bins contain useful info
bins2freq.append(k*sr/n_fft)
binslist.append(k)
wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)
wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)
    else:
        raise ValueError("Please select a valid frequency scale: 'linear', 'log', or 'no'")
return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)
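# Usage sketch (illustrative, not part of the original file): the returned kernels
# implement an STFT as a 1-D convolution, which is how the spectrogram layers consume them.
#
# >>> import torch
# >>> from torch.nn.functional import conv1d
# >>> wsin, wcos, bins2freq, binslist, win = create_fourier_kernels(n_fft=2048, freq_scale='no')
# >>> x = torch.randn(1, 1, 16000)              # (batch, 1, len_audio)
# >>> real = conv1d(x, torch.tensor(wcos*win), stride=512)
# >>> imag = conv1d(x, torch.tensor(wsin*win), stride=512)
# >>> magnitude = (real**2 + imag**2).sqrt()    # (1, freq_bins, time_steps)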
# Tools for CQT
def create_cqt_kernels(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1,
window='hann', fmax=None, topbin_check=True):
"""
Automatically create CQT kernels in time domain
"""
fftLen = 2**nextpow2(np.ceil(Q * fs / fmin))
# minWin = 2**nextpow2(np.ceil(Q * fs / fmax))
if (fmax != None) and (n_bins == None):
n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
elif (fmax == None) and (n_bins != None):
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
else:
warnings.warn('If fmax is given, n_bins will be ignored',SyntaxWarning)
n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
if np.max(freqs) > fs/2 and topbin_check==True:
        raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, '
                         'please reduce the n_bins'.format(np.max(freqs)))
tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
lengths = np.ceil(Q * fs / freqs)
for k in range(0, int(n_bins)):
freq = freqs[k]
l = np.ceil(Q * fs / freq)
# Centering the kernels
if l%2==1: # pad more zeros on RHS
start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1
else:
start = int(np.ceil(fftLen / 2.0 - l / 2.0))
sig = get_window_dispatch(window,int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l
if norm: # Normalizing the filter # Trying to normalize like librosa
tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm)
else:
tempKernel[k, start:start + int(l)] = sig
# specKernel[k, :] = fft(tempKernel[k])
# return specKernel[:,:fftLen//2+1], fftLen, torch.tensor(lenghts).float()
return tempKernel, fftLen, torch.tensor(lengths).float(), freqs
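# Usage sketch (illustrative, not part of the original file): the kernels are
# returned in the time domain, centred inside a power-of-two FFT length.
#
# >>> Q = 1/(2**(1/12)-1)                       # the Q used by the CQT classes for 12 bins/octave
# >>> kernels, fft_len, lengths, freqs = create_cqt_kernels(Q, fs=22050, fmin=32.7)
# >>> kernels.shape                             # (n_bins, fft_len), complex64
# (84, 16384)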
def get_window_dispatch(window, N, fftbins=True):
if isinstance(window, str):
return get_window(window, N, fftbins=fftbins)
elif isinstance(window, tuple):
if window[0] == 'gaussian':
assert window[1] >= 0
sigma = np.floor(- N / 2 / np.sqrt(- 2 * np.log(10**(- window[1] / 20))))
return get_window(('gaussian', sigma), N, fftbins=fftbins)
else:
Warning("Tuple windows may have undesired behaviour regarding Q factor")
elif isinstance(window, float):
Warning("You are using Kaiser window with beta factor " + str(window) + ". Correct behaviour not checked.")
else:
raise Exception("The function get_window from scipy only supports strings, tuples and floats.")
def get_cqt_complex(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding):
"""Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
for how to multiple the STFT result with the CQT kernel
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
a constant Q transform.” (1992)."""
# STFT, converting the audio input from time domain to frequency domain
try:
x = padding(x) # When center == True, we need padding at the beginning and ending
    except Exception:
warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
"padding with reflection mode might not be the best choice, try using constant padding",
UserWarning)
x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
return torch.stack((CQT_real, CQT_imag),-1)
def get_cqt_complex2(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding, wcos=None, wsin=None):
"""Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
for how to multiple the STFT result with the CQT kernel
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
a constant Q transform.” (1992)."""
# STFT, converting the audio input from time domain to frequency domain
try:
x = padding(x) # When center == True, we need padding at the beginning and ending
    except Exception:
warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
"padding with reflection mode might not be the best choice, try using constant padding",
UserWarning)
x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
if wcos==None or wsin==None:
CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
else:
fourier_real = conv1d(x, wcos, stride=hop_length)
fourier_imag = conv1d(x, wsin, stride=hop_length)
# Multiplying input with the CQT kernel in freq domain
CQT_real, CQT_imag = complex_mul((cqt_kernels_real, cqt_kernels_imag),
(fourier_real, fourier_imag))
return torch.stack((CQT_real, CQT_imag),-1)
def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03):
"""
    Design a lowpass FIR filter kernel by calculating the highest frequency we need to
    preserve and the lowest frequency we allow to pass through.
Note that frequency is on a scale from 0 to 1 where 0 is 0 and 1 is Nyquist frequency of
the signal BEFORE downsampling.
"""
# transitionBandwidth = 0.03
passbandMax = band_center / (1 + transitionBandwidth)
stopbandMin = band_center * (1 + transitionBandwidth)
    # Unlike interactive filter design tools, this function does not
    # allow us to specify how closely the filter matches our
    # specifications. Instead, we specify the length of the kernel.
    # The longer the kernel is, the more precisely it will match.
# kernelLength = 256
# We specify a list of key frequencies for which we will require
# that the filter match a specific output gain.
# From [0.0 to passbandMax] is the frequency range we want to keep
# untouched and [stopbandMin, 1.0] is the range we want to remove
keyFrequencies = [0.0, passbandMax, stopbandMin, 1.0]
# We specify a list of output gains to correspond to the key
# frequencies listed above.
    # The first two gains are 1.0 because they correspond to the first
    # two key frequencies; the last two are 0.0 because they
    # correspond to the stopband frequencies
gainAtKeyFrequencies = [1.0, 1.0, 0.0, 0.0]
# This command produces the filter kernel coefficients
filterKernel = signal.firwin2(kernelLength, keyFrequencies, gainAtKeyFrequencies)
return filterKernel.astype(np.float32)
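# Usage sketch (illustrative, not part of the original file): a half-band kernel
# of the kind consumed by downsampling_by_2 defined earlier in this module.
#
# >>> import torch
# >>> k = create_lowpass_filter(band_center=0.5, kernelLength=256)
# >>> k = torch.tensor(k)[None, None, :]        # (1, 1, kernelLength) conv kernel
# >>> x = torch.randn(1, 1, 16000)
# >>> x_down = downsampling_by_2(x, k)          # (1, 1, 8000)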
def get_early_downsample_params(sr, hop_length, fmax_t, Q, n_octaves, verbose):
"""Used in CQT2010 and CQT2010v2"""
window_bandwidth = 1.5 # for hann window
filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
sr, hop_length, downsample_factor = early_downsample(sr,
hop_length,
n_octaves,
sr//2,
filter_cutoff)
if downsample_factor != 1:
if verbose==True:
print("Can do early downsample, factor = ", downsample_factor)
earlydownsample=True
# print("new sr = ", sr)
# print("new hop_length = ", hop_length)
early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor,
kernelLength=256,
transitionBandwidth=0.03)
early_downsample_filter = torch.tensor(early_downsample_filter)[None, None, :]
else:
if verbose==True:
print("No early downsampling is required, downsample_factor = ", downsample_factor)
early_downsample_filter = None
earlydownsample=False
return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample
def early_downsample(sr, hop_length, n_octaves,
nyquist, filter_cutoff):
    '''Return the new sampling rate and hop length after early downsampling'''
downsample_count = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
# print("downsample_count = ", downsample_count)
downsample_factor = 2**(downsample_count)
hop_length //= downsample_factor # Getting new hop_length
new_sr = sr / float(downsample_factor) # Getting new sampling rate
sr = new_sr
return sr, hop_length, downsample_factor
# The following two downsampling count functions are obtained from librosa CQT
# They are used to determine the number of pre-resamplings if the starting and ending frequencies
# are both in the low frequency region.
def early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves):
'''Compute the number of early downsampling operations'''
downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist /
filter_cutoff)) - 1) - 1)
# print("downsample_count1 = ", downsample_count1)
num_twos = nextpow2(hop_length)
downsample_count2 = max(0, num_twos - n_octaves + 1)
# print("downsample_count2 = ",downsample_count2)
return min(downsample_count1, downsample_count2)
import setuptools
import codecs
import os.path
def read(rel_path):
here = os.path.abspath(os.path.dirname(__file__))
with codecs.open(os.path.join(here, rel_path), 'r') as fp:
return fp.read()
def get_version(rel_path):
for line in read(rel_path).splitlines():
if line.startswith('__version__'):
delim = '"' if '"' in line else "'"
return line.split(delim)[1]
else:
raise RuntimeError("Unable to find version string.")
setuptools.setup(
name="nnAudio", # Replace with your own username
version=get_version("nnAudio/__init__.py"),
author="KinWaiCheuk",
author_email="u3500684@connect.hku.hk",
description="A fast GPU audio processing toolbox with 1D convolutional neural network",
long_description='',
long_description_content_type="text/markdown",
url="https://github.com/KinWaiCheuk/nnAudio",
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
)
# Creating parameters for STFT test
"""
It is equivalent to
[(1024, 128, 'ones'),
(1024, 128, 'hann'),
(1024, 128, 'hamming'),
(2048, 128, 'ones'),
(2048, 512, 'ones'),
(2048, 128, 'hann'),
(2048, 512, 'hann'),
(2048, 128, 'hamming'),
(2048, 512, 'hamming'),
(None, None, None)]
"""
stft_parameters = []
n_fft = [1024,2048]
hop_length = {128,512,1024}
window = ['ones', 'hann', 'hamming']
for i in n_fft:
for k in window:
for j in hop_length:
if j < (i/2):
stft_parameters.append((i,j,k))
stft_parameters.append((256, None, 'hann'))
stft_with_win_parameters = []
n_fft = [512,1024]
win_length = [400, 900]
hop_length = {128,256}
for i in n_fft:
for j in win_length:
if j < i:
for k in hop_length:
if k < (i/2):
stft_with_win_parameters.append((i,j,k))
mel_win_parameters = [(512,400), (1024, 1000)]
import pytest
import librosa
import torch
import matplotlib.pyplot as plt
from scipy.signal import chirp, sweep_poly
from nnAudio.Spectrogram import *
from parameters import *
gpu_idx=0
# librosa example audio for testing
example_y, example_sr = librosa.load(librosa.util.example_audio_file())
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_inverse2(n_fft, hop_length, window, device):
x = torch.tensor(example_y,device=device)
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
istft = iSTFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
X = stft(x.unsqueeze(0), output_format="Complex")
x_recon = istft(X, length=x.shape[0], onesided=True).squeeze()
assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-5, atol=1e-3)
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_inverse(n_fft, hop_length, window, device):
x = torch.tensor(example_y, device=device)
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, iSTFT=True).to(device)
X = stft(x.unsqueeze(0), output_format="Complex")
x_recon = stft.inverse(X, length=x.shape[0]).squeeze()
assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
# def test_inverse_GPU(n_fft, hop_length, window):
# x = torch.tensor(example_y,device=f'cuda:{gpu_idx}')
# stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
# X = stft(x.unsqueeze(0), output_format="Complex")
# x_recon = stft.inverse(X, num_samples=x.shape[0]).squeeze()
# assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_complex(n_fft, hop_length, window, device):
x = example_y
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex")
X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze()
X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \
np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3)
assert real_diff and imag_diff
# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
# def test_stft_complex_GPU(n_fft, hop_length, window):
# x = example_y
# stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
# X = stft(torch.tensor(x,device=f'cuda:{gpu_idx}').unsqueeze(0), output_format="Complex")
# X_real, X_imag = X[:, :, :, 0].squeeze().detach().cpu(), X[:, :, :, 1].squeeze().detach().cpu()
# X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
# real_diff, imag_diff = np.allclose(X_real, X_librosa.real, rtol=1e-3, atol=1e-3), \
# np.allclose(X_imag, X_librosa.imag, rtol=1e-3, atol=1e-3)
# assert real_diff and imag_diff
@pytest.mark.parametrize("n_fft, win_length, hop_length", stft_with_win_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_complex_winlength(n_fft, win_length, hop_length, device):
x = example_y
stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex")
X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze()
X_librosa = librosa.stft(x, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \
np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3)
assert real_diff and imag_diff
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_magnitude(device):
x = example_y
stft = STFT(n_fft=2048, hop_length=512).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Magnitude").squeeze()
X_librosa, _ = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512))
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_phase(device):
x = example_y
stft = STFT(n_fft=2048, hop_length=512).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Phase")
X_real, X_imag = torch.cos(X).squeeze(), torch.sin(X).squeeze()
_, X_librosa = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512))
real_diff, imag_diff = np.mean(np.abs(X_real.cpu().numpy() - X_librosa.real)), \
np.mean(np.abs(X_imag.cpu().numpy() - X_librosa.imag))
# np.allclose is too strict for comparing the phase against librosa's.
# Hence, for phase we use the average element-wise distance as the test metric.
assert real_diff < 2e-4 and imag_diff < 2e-4
@pytest.mark.parametrize("n_fft, win_length", mel_win_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_mel_spectrogram(n_fft, win_length, device):
x = example_y
melspec = MelSpectrogram(n_fft=n_fft, win_length=win_length, hop_length=512).to(device)
X = melspec(torch.tensor(x, device=device).unsqueeze(0)).squeeze()
X_librosa = librosa.feature.melspectrogram(x, n_fft=n_fft, win_length=win_length, hop_length=512)
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_1992(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT1992(sr=fs, fmin=220, output_format="Magnitude",
n_bins=80, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Complex
stft = CQT1992(sr=fs, fmin=220, output_format="Complex",
n_bins=80, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Phase
stft = CQT1992(sr=fs, fmin=220, output_format="Phase",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
assert True
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_2010(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT2010(sr=fs, fmin=110, output_format="Magnitude",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Complex
stft = CQT2010(sr=fs, fmin=110, output_format="Complex",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Phase
stft = CQT2010(sr=fs, fmin=110, output_format="Phase",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
assert True
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_1992_v2_log(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-mag-ground-truth.npy")
X = torch.log(X + 1e-5)
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Phase
stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-phase-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_1992_v2_linear(device):
# Linear sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='linear')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-mag-ground-truth.npy")
X = torch.log(X + 1e-5)
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Phase
stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-phase-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_2010_v2_log(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
X = torch.log(X + 1e-2)
# np.save("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# np.save("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# # Phase
# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase",
# n_bins=207, bins_per_octave=24)
# X = stft(torch.tensor(x, device=device).unsqueeze(0))
# # np.save("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth", X.cpu())
# ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth.npy")
# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_2010_v2_linear(device):
# Linear sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='linear')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
X = torch.log(X + 1e-2)
# np.save("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# np.save("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Phase
# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase",
# n_bins=207, bins_per_octave=24)
# X = stft(torch.tensor(x, device=device).unsqueeze(0))
# # np.save("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth", X.cpu())
# ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth.npy")
# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_mfcc(device):
x = example_y
mfcc = MFCC(sr=example_sr).to(device)
X = mfcc(torch.tensor(x, device=device).unsqueeze(0)).squeeze()
X_librosa = librosa.feature.mfcc(x, sr=example_sr)
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
x = torch.randn((4,44100)) # Create a batch of input for the following Data.Parallel test
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_STFT_Parallel(device):
spec_layer = STFT(hop_length=512, n_fft=2048, window='hann',
freq_scale='no',
output_format='Complex').to(device)
inverse_spec_layer = iSTFT(hop_length=512, n_fft=2048, window='hann',
freq_scale='no').to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
inverse_spec_layer_parallel = torch.nn.DataParallel(inverse_spec_layer)
spec = spec_layer_parallel(x)
x_recon = inverse_spec_layer_parallel(spec, onesided=True, length=x.shape[-1])
assert np.allclose(x_recon.detach().cpu(), x.detach().cpu(), rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_MelSpectrogram_Parallel(device):
spec_layer = MelSpectrogram(sr=22050, n_fft=2048, n_mels=128, hop_length=512,
window='hann', center=True, pad_mode='reflect',
power=2.0, htk=False, fmin=0.0, fmax=None, norm=1,
verbose=True).to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_MFCC_Parallel(device):
spec_layer = MFCC().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT1992_Parallel(device):
spec_layer = CQT1992(fmin=110, n_bins=60, bins_per_octave=12).to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT1992v2_Parallel(device):
spec_layer = CQT1992v2().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT2010_Parallel(device):
spec_layer = CQT2010().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT2010v2_Parallel(device):
spec_layer = CQT2010v2().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
from typing import Tuple
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
def frame(x: Tensor,
num_samples: Tensor,
win_length: int,
hop_length: int,
clip: bool = True) -> Tuple[Tensor, Tensor]:
"""Extract frames from audio.
Parameters
----------
x : Tensor
Shape (N, T), batched waveform.
num_samples : Tensor
Shape (N, ), number of samples of each waveform.
win_length : int
Window length.
hop_length : int
Number of samples shifted between adjacent frames.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
default True
Returns
-------
frames : Tensor
Shape (N, T', win_length).
num_frames : Tensor
Shape (N, ) number of valid frames
"""
assert hop_length <= win_length
num_frames = (num_samples - win_length) // hop_length
padding = (0, 0)
if not clip:
num_frames += 1
# NOTE: pad hop_length - 1 to the right to ensure that there is at most
# one frame dangling at the right edge
padding = (0, hop_length - 1)
weight = paddle.eye(win_length).unsqueeze(1)
frames = F.conv1d(x.unsqueeze(1),
weight,
padding=padding,
stride=(hop_length, ))
return frames, num_frames
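# Hedged usage sketch (added for illustration, not part of the original module):
# extract 25 ms frames with a 10 ms hop from a batch of two dummy 16 kHz
# waveforms. The concrete lengths are illustrative only.
def _example_frame():
    x = paddle.randn([2, 16000])                    # (N, T)
    num_samples = paddle.to_tensor([16000, 12000])  # valid length of each waveform
    frames, num_frames = frame(x, num_samples, win_length=400, hop_length=160)
    return frames, num_frames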
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
Parameters
------------
n_fft : int
Number of samples in a frame.
hop_length : int
Number of samples shifted between adjacent frames.
win_length : int
Length of the window.
clip : bool
Whether to clip audio that does not fit into the last frame.
"""
def __init__(self,
n_fft: int,
hop_length: int,
win_length: int,
window_type: str = None,
clip: bool = True):
super().__init__()
self.hop_length = hop_length
self.win_length = win_length  # needed by forward() when computing num_frames
self.n_bin = 1 + n_fft // 2
self.n_fft = n_fft
self.clip = clip
# calculate window
if window_type is None:
window = np.ones(win_length)
elif window_type == "hann":
window = np.hanning(win_length)
elif window_type == "hamming":
window = np.hamming(win_length)
else:
raise ValueError("Not supported yet!")
# keep the window length consistent with the DFT kernel below; when
# win_length < n_fft the kernel itself is truncated to win_length columns,
# so only the win_length > n_fft case needs trimming here
if win_length > n_fft:
window = window[:n_fft]
# (n_bin, kernel_size) complex
kernel_size = min(n_fft, win_length)
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size]
w_real = weight.real
w_imag = weight.imag
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
# (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
w = np.expand_dims(w, 1)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
num_samples : Tensor
Number of samples of each waveform.
Returns
------------
D : Tensor
Shape (N, T', n_bins, 2), the spectrogram.
num_frames : Tensor
Shape (N,), number of frames of each spectrogram.
"""
num_frames = (num_samples - self.win_length) // self.hop_length
padding = (0, 0)
if not self.clip:
num_frames += 1
padding = (0, self.hop_length - 1)
batch_size = paddle.shape(x)[0]
x = x.unsqueeze(-1)  # (B, T, 1) to match the NLC data format
D = F.conv1d(x,
self.weight,
stride=(self.hop_length, ),
padding=padding,
data_format="NLC")
# output channels are ordered [real bins, imag bins], so split them before
# moving the real/imag axis to the end
D = paddle.reshape(D, [batch_size, -1, 2, self.n_bin])
D = paddle.transpose(D, [0, 1, 3, 2])
return D, num_frames
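# Hedged usage sketch (added for illustration, not part of the original module):
# build the STFT layer above and run it on a dummy batch. The hyper-parameters
# below are arbitrary examples.
def _example_stft():
    stft = STFT(n_fft=512, hop_length=160, win_length=400, window_type="hann")
    x = paddle.randn([2, 16000])
    num_samples = paddle.to_tensor([16000, 12000])
    spec, num_frames = stft(x, num_samples)  # spec: (N, T', n_bin, 2)
    return spec, num_frames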
from .sentence_split import split
from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM
from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num
from .chronology import RE_TIME, RE_DATE, RE_DATE2
from .chronology import replace_time, replace_date, replace_date2
from .quantifier import RE_TEMPERATURE
from .quantifier import replace_temperature
from .phone import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone
from .char_convert import tranditional_to_simplified
from .constants import F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE
def normalize_sentence(sentence):
# basic character conversions
sentence = tranditional_to_simplified(sentence)
sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
F2H_DIGITS).translate(F2H_SPACE)
# number related NSW verbalization
sentence = RE_DATE.sub(replace_date, sentence)
sentence = RE_DATE2.sub(replace_date2, sentence)
sentence = RE_TIME.sub(replace_time, sentence)
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
sentence = RE_RANGE.sub(replace_range, sentence)
sentence = RE_FRAC.sub(replace_frac, sentence)
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
sentence = RE_MOBILE_PHONE.sub(replace_phone, sentence)
sentence = RE_TELEPHONE.sub(replace_phone, sentence)
sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
sentence = RE_NUMBER.sub(replace_number, sentence)
return sentence
def normalize(text):
sentences = split(text)
sentences = [normalize_sentence(sent) for sent in sentences]
return sentences
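# Hedged usage sketch (added for illustration, not part of the original module):
# run the normalizer on a short sentence containing a date, a temperature and a
# percentage. The exact output depends on the rule modules imported above; the
# sample text is illustrative only.
def _example_normalize():
    return normalize("今天是2021年5月1日,气温-3°C,涨幅达3.5%。")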
"""Traditional and simplified Chinese conversion with
`opencc <https://github.com/BYVoid/OpenCC>`_.
"""
import opencc
_t2s_converter = opencc.OpenCC("t2s.json")
_s2t_converter = opencc.OpenCC('s2t.json')
def tranditional_to_simplified(text: str) -> str:
return _t2s_converter.convert(text)
def simplified_to_traditional(text: str) -> str:
return _s2t_converter.convert(text)
import re
from .num import verbalize_cardinal, verbalize_digit, num2str, DIGITS
def _time_num2str(num_string: str) -> str:
"""A special case for verbalizing number in time."""
result = num2str(num_string.lstrip('0'))
if num_string.startswith('0'):
result = DIGITS['0'] + result
return result
# time-of-day expression
RE_TIME = re.compile(
r'([0-1]?[0-9]|2[0-3])'
r':([0-5][0-9])'
r'(:([0-5][0-9]))?'
)
def replace_time(match: re.Match) -> str:
hour = match.group(1)
minute = match.group(2)
second = match.group(4)
result = f"{num2str(hour)}点"
if minute.lstrip('0'):
result += f"{_time_num2str(minute)}分"
if second and second.lstrip('0'):
result += f"{_time_num2str(second)}秒"
return result
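# Hedged usage sketch (added for illustration, not part of the original module):
# RE_TIME.sub(replace_time, ...) verbalizes a clock time in place, e.g.
# "8:05:30" becomes "八点零五分三十秒". The sentence below is illustrative only.
def _example_replace_time():
    return RE_TIME.sub(replace_time, "会议8:05开始")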
RE_DATE = re.compile(
r'(\d{4}|\d{2})年'
r'((0?[1-9]|1[0-2])月)?'
r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?'
)
def replace_date(match: re.Match) -> str:
year = match.group(1)
month = match.group(3)
day = match.group(5)
result = ""
if year:
result += f"{verbalize_digit(year)}年"
if month:
result += f"{verbalize_cardinal(month)}月"
if day:
result += f"{verbalize_cardinal(day)}{match.group(9)}"
return result
# YYYY/MM/DD or YYYY-MM-DD style dates separated by -, /, . or space
RE_DATE2 = re.compile(
r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])'
)
def replace_date2(match: re.Match) -> str:
year = match.group(1)
month = match.group(3)
day = match.group(4)
result = ""
if year:
result += f"{verbalize_digit(year)}年"
if month:
result += f"{verbalize_cardinal(month)}月"
if day:
result += f"{verbalize_cardinal(day)}日"
return result
import string
import re
from pypinyin.constants import SUPPORT_UCS4
# full-width / half-width character conversion
# full-width -> half-width mapping for ASCII letters (num: 52)
F2H_ASCII_LETTERS = {
chr(ord(char) + 65248): char
for char in string.ascii_letters
}
# half-width -> full-width mapping for ASCII letters
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
# full-width -> half-width mapping for digits (num: 10)
F2H_DIGITS = {
chr(ord(char) + 65248): char
for char in string.digits
}
# half-width -> full-width mapping for digits
H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}
# full-width -> half-width mapping for punctuation (num: 32)
F2H_PUNCTUATIONS = {
chr(ord(char) + 65248): char
for char in string.punctuation
}
# half-width -> full-width mapping for punctuation
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
# space (num: 1)
F2H_SPACE = {'\u3000': ' '}
H2F_SPACE = {' ': '\u3000'}
# 非"有拼音的汉字"的字符串,可用于NSW提取
if SUPPORT_UCS4:
RE_NSW = re.compile(
r'(?:[^'
r'\u3007' # 〇
r'\u3400-\u4dbf' # CJK Extension A: [3400-4DBF]
r'\u4e00-\u9fff' # CJK Unified Ideographs: [4E00-9FFF]
r'\uf900-\ufaff' # CJK Compatibility Ideographs: [F900-FAFF]
r'\U00020000-\U0002A6DF' # CJK Extension B: [20000-2A6DF]
r'\U0002A703-\U0002B73F' # CJK Extension C: [2A700-2B73F]
r'\U0002B740-\U0002B81D' # CJK Extension D: [2B740-2B81D]
r'\U0002F80A-\U0002FA1F' # CJK Compatibility Supplement: [2F800-2FA1F]
r'])+'
)
else:
RE_NSW = re.compile( # pragma: no cover
r'(?:[^'
r'\u3007' # 〇
r'\u3400-\u4dbf' # CJK Extension A: [3400-4DBF]
r'\u4e00-\u9fff' # CJK Unified Ideographs: [4E00-9FFF]
r'\uf900-\ufaff' # CJK Compatibility Ideographs: [F900-FAFF]
r'])+'
)
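# Hedged usage sketch (added for illustration, not part of the original module):
# RE_NSW picks out the spans that are not ordinary Chinese characters, which is
# what the normalizer needs to verbalize, e.g. roughly ['GDP', '3.5%'] below.
def _example_re_nsw():
    return RE_NSW.findall("GDP增长3.5%")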
"""
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""
import re
from typing import List
from collections import OrderedDict
DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
UNITS = OrderedDict({
1: '十',
2: '百',
3: '千',
4: '万',
8: '亿',
})
# fraction expression
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
def replace_frac(match: re.Match) -> str:
sign = match.group(1)
nominator = match.group(2)
denominator = match.group(3)
sign: str = "负" if sign else ""
nominator: str = num2str(nominator)
denominator: str = num2str(denominator)
result = f"{sign}{denominator}分之{nominator}"
return result
# percentage expression
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
def replace_percentage(match: re.Match) -> str:
sign = match.group(1)
percent = match.group(2)
sign: str = "负" if sign else ""
percent: str = num2str(percent)
result = f"{sign}百分之{percent}"
return result
# integer expression
# integers with or without a minus sign, e.g. 12, -10
RE_INTEGER = re.compile(
r'(-?)'
r'(\d+)'
)
# ID numbers - unsigned integers read digit by digit
# e.g. 00078
RE_DEFAULT_NUM = re.compile(r'\d{4}\d*')
def replace_default_num(match: re.Match):
number = match.group(0)
return verbalize_digit(number)
# number expression
# 1. integers: -10, 10;
# 2. floating point numbers: 10.2, -0.3
# 3. pure decimals without a sign or integer part: .22, .38
RE_NUMBER = re.compile(
r'(-?)((\d+)(\.\d+)?)'
r'|(\.(\d+))'
)
def replace_number(match: re.Match) -> str:
sign = match.group(1)
number = match.group(2)
pure_decimal = match.group(5)
if pure_decimal:
result = num2str(pure_decimal)
else:
sign: str = "负" if sign else ""
number: str = num2str(number)
result = f"{sign}{number}"
return result
# range expression
# 12-23, 12~23
RE_RANGE = re.compile(
r'(\d+)[-~](\d+)'
)
def replace_range(match: re.Match) -> str:
first, second = match.group(1), match.group(2)
first: str = num2str(first)
second: str = num2str(second)
result = f"{first}{second}"
return result
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
stripped = value_string.lstrip('0')
if len(stripped) == 0:
return []
elif len(stripped) == 1:
if use_zero and len(stripped) < len(value_string):
return [DIGITS['0'], DIGITS[stripped]]
else:
return [DIGITS[stripped]]
else:
largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped))
first_part = value_string[:-largest_unit]
second_part = value_string[-largest_unit:]
return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)
def verbalize_cardinal(value_string: str) -> str:
if not value_string:
return ''
# 000 -> '零' , 0 -> '零'
value_string = value_string.lstrip('0')
if len(value_string) == 0:
return DIGITS['0']
result_symbols = _get_value(value_string)
# verbalized number starting with '一十*' is abbreviated as `十*`
if len(result_symbols) >= 2 and result_symbols[0] == DIGITS['1'] and result_symbols[1] == UNITS[1]:
result_symbols = result_symbols[1:]
return ''.join(result_symbols)
def verbalize_digit(value_string: str, alt_one=False) -> str:
result_symbols = [DIGITS[digit] for digit in value_string]
result = ''.join(result_symbols)
if alt_one:
result = result.replace("一", "幺")
return result
def num2str(value_string: str) -> str:
integer_decimal = value_string.split('.')
if len(integer_decimal) == 1:
integer = integer_decimal[0]
decimal = ''
elif len(integer_decimal) == 2:
integer, decimal = integer_decimal
else:
raise ValueError(f"The value string: '{value_string}' has more than one decimal point in it.")
result = verbalize_cardinal(integer)
decimal = decimal.rstrip('0')
if decimal:
# '.22' is verbalized as '点二二'
# '3.20' is verbalized as '三点二'
result += '点' + verbalize_digit(decimal)
return result
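# Hedged worked examples (added for illustration, not part of the original module):
# a few concrete verbalizations produced by the rules above; the inputs are
# illustrative only.
def _example_num2str():
    return [
        num2str("10002"),           # 一万零二
        num2str("3.20"),            # 三点二 (trailing zeros are dropped)
        verbalize_cardinal("015"),  # 十五
        verbalize_digit("2021"),    # 二零二一
    ]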
import re
from .num import verbalize_digit
# normalize landline / mobile phone numbers
# mobile numbers
# http://www.jihaoba.com/news/show/13680
# China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
# China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
# China Telecom: 133, 153, 189, 180, 181, 177
RE_MOBILE_PHONE= re.compile(
r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(
r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
def phone2str(phone_string: str, mobile=True) -> str:
if mobile:
sp_parts = phone_string.strip('+').split()
result = ''.join(
[verbalize_digit(part, alt_one=True) for part in sp_parts])
return result
else:
sil_parts = phone_string.split('-')
result = ''.join(
[verbalize_digit(part, alt_one=True) for part in sil_parts])
return result
def replace_phone(match: re.Match) -> str:
return phone2str(match.group(0))
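# Hedged usage sketch (added for illustration, not part of the original module):
# mobile numbers are read digit by digit with 幺 in place of 一, landline numbers
# are split on '-' first. The numbers below are made up for the example.
def _example_phone2str():
    return [
        phone2str("+86 13812345678"),             # mobile form
        phone2str("010-12345678", mobile=False),  # landline form
    ]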
import re
from .num import num2str
# temperature expression; the temperature affects how the minus sign is read
# e.g. -3°C is read as 零下三度
RE_TEMPERATURE = re.compile(
r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)'
)
def replace_temperature(match: re.Match) -> str:
sign = match.group(1)
temperature = match.group(2)
unit = match.group(3)
sign: str = "零下" if sign else ""
temperature: str = num2str(temperature)
unit: str = "摄氏度" if unit == "摄氏度" else "度"
result = f"{sign}{temperature}{unit}"
return result
import re
from typing import List
SENTENCE_SPLITOR = re.compile(r'([。!?][”’]?)')
def split(text: str) -> List[str]:
"""Split long text into sentences with sentence-splitting punctuations.
Parameters
----------
text : str
The input text.
Returns
-------
List[str]
Sentences.
"""
text = SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
return sentences
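# Hedged usage sketch (added for illustration, not part of the original module):
# the splitter keeps the sentence-ending punctuation with each sentence; the
# text below is illustrative only.
def _example_split():
    # roughly: ['今天天气不错。', '我们去公园吧!', '好的?']
    return split("今天天气不错。我们去公园吧!好的?")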
SHELL:= /bin/bash
PYTHON:= python3.7
.PHONY: all clean
all: virtualenv kenlm.done sox.done soxbindings.done
all: virtualenv kenlm.done sox.done soxbindings.done mfa.done
virtualenv:
test -d venv || virtualenv -p $(PYTHON) venv
@@ -18,8 +19,8 @@ kenlm.done:
apt install -y build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
cd kenlm && python setup.py install
rm -rf kenlm/build && mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
source venv/bin/activate; cd kenlm && python setup.py install
touch kenlm.done
sox.done:
@@ -31,5 +32,10 @@ sox.done:
soxbindings.done:
test -d soxbindings || git clone https://github.com/pseeth/soxbindings.git
source venv/bin/activate; cd soxbindings && python3 setup.py install
source venv/bin/activate; cd soxbindings && python setup.py install
touch soxbindings.done
mfa.done:
test -d montreal-forced-aligner || wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz
tar xvf montreal-forced-aligner_linux.tar.gz
touch mfa.done
1. kaldi
deps gcc, mkl or openblas
2. OpenFST/ngram/pynini
deps gcc
3. MFA
deps kaldi
#!/bin/bash
set -e
set -x
# gcc
apt update -y
apt install build-essential -y
apt install software-properties-common -y
add-apt-repository ppa:ubuntu-toolchain-r/test
apt install gcc-8 g++-8 -y
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 80
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 80
update-alternatives --config gcc
# gfortran
apt-get install gfortran-8
#!/bin/bash
# Installation script for Kaldi
#
set -e
apt-get install subversion -y
KALDI_GIT="--depth 1 -b master https://github.com/kaldi-asr/kaldi.git"
KALDI_DIR="$PWD/kaldi"
if [ ! -d "$KALDI_DIR" ]; then
git clone $KALDI_GIT $KALDI_DIR
else
echo "$KALDI_DIR already exists!"
fi
cd "$KALDI_DIR/tools"
git pull
# Prevent kaldi from switching default python version
mkdir -p "python"
touch "python/.use_default_python"
./extras/check_dependencies.sh
make -j4
pushd ../src
./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${KALDI_DIR}/../OpenBLAS/install
make clean -j && make depend -j && make -j4
popd
echo "Done installing Kaldi."
#!/bin/bash
apt install -y build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
rm -rf kenlm/build && mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
#!/usr/bin/env bash
VER=1.10
WGET=${WGET:-wget}
if [ ! -f liblbfgs-$VER.tar.gz ]; then
if [ -d "$DOWNLOAD_DIR" ]; then
cp -p "$DOWNLOAD_DIR/liblbfgs-$VER.tar.gz" . || exit 1
else
$WGET https://github.com/downloads/chokkan/liblbfgs/liblbfgs-$VER.tar.gz || exit 1
fi
fi
tar -xzf liblbfgs-$VER.tar.gz
cd liblbfgs-$VER
./configure --prefix=`pwd`
make
# due to the liblbfgs project directory structure, we have to use -i
# but the errors are completely harmless
make -i install
cd ..
(
[ ! -z "${LIBLBFGS}" ] && \
echo >&2 "LIBLBFGS variable is aleady defined. Undefining..." && \
unset LIBLBFGS
[ -f ./env.sh ] && . ./env.sh
[ ! -z "${LIBLBFGS}" ] && \
echo >&2 "libLBFGS config is already in env.sh" && exit
wd=`pwd`
wd=`readlink -f $wd || pwd`
echo "export LIBLBFGS=$wd/liblbfgs-1.10"
echo export LD_LIBRARY_PATH='${LD_LIBRARY_PATH:-}':'${LIBLBFGS}'/lib/.libs
) >> env.sh
#!/bin/bash
# install openblas, kaldi before
test -d Montreal-Forced-Aligner || git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git
pushd Montreal-Forced-Aligner && git checkout v2.0.0a7 && python setup.py install
test -d kaldi || { echo "need install kaldi first"; exit 1;}
mfa thirdparty kaldi $PWD/kaldi
mfa thirdparty validate
echo "install mfa pass."
#!/usr/bin/env bash
WGET=${WGET:-wget}
# The script automatically choose default settings of miniconda for installation
# Miniconda will be installed in the HOME directory. ($HOME/miniconda3).
# Also don't make miniconda's python as default.
if [ -d "$DOWNLOAD_DIR" ]; then
cp -p "$DOWNLOAD_DIR/Miniconda3-latest-Linux-x86_64.sh" . || exit 1
else
$WGET https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh || exit 1
fi
bash Miniconda3-latest-Linux-x86_64.sh -b
$HOME/miniconda3/bin/python -m pip install --user tqdm
$HOME/miniconda3/bin/python -m pip install --user scikit-learn
$HOME/miniconda3/bin/python -m pip install --user librosa
$HOME/miniconda3/bin/python -m pip install --user h5py
#!/usr/bin/env bash
# Intel MKL is now freely available even for commercial use. This script
# attempts to install the MKL package automatically from Intel's repository.
#
# For manual repository setup instructions, see:
# https://software.intel.com/articles/installing-intel-free-libs-and-python-yum-repo
# https://software.intel.com/articles/installing-intel-free-libs-and-python-apt-repo
#
# For other package managers, or non-Linux platforms, see:
# https://software.intel.com/mkl/choose-download
set -o pipefail
default_package=intel-mkl-64bit-2020.0-088
yum_repo='https://yum.repos.intel.com/mkl/setup/intel-mkl.repo'
apt_repo='https://apt.repos.intel.com/mkl'
intel_key_url='https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB'
Usage () {
cat >&2 <<EOF
Usage: $0 [-s] [<MKL-package>]
Checks if MKL is present on the system, and/or attempts to install it.
If <MKL-package> is not provided, ${default_package} will be installed.
Intel packages are installed under the /opt/intel directory. You should be root
to install MKL into this directory; run this script using the sudo command.
Options:
-s - Skip check for MKL being already present.
-p <suse|redhat|debian|fedora|arch> -- Force type of package management. Use only
if automatic detection fails, as instructed.
-h - Show this message.
Environment:
CC The C compiler to use for MKL check. If not set, uses 'cc'.
EOF
exit 2
}
Fatal () { echo "$0: $@"; exit 1; }
Have () { type -t "$1" >/dev/null; }
# Option values.
skip_cc=
distro=
while getopts ":hksp:" opt; do
case ${opt} in
h) Usage ;;
s) skip_cc=yes ;;
p) case $OPTARG in
suse|redhat|debian|fedora|arch) distro=$OPTARG ;;
*) Fatal "invalid value -p '${OPTARG}'. " \
"Allowed: 'suse', 'redhat', 'debian', 'fedora', or 'arch'."
esac ;;
\?) echo >&2 "$0: invalid option -${OPTARG}."; Usage ;;
esac
done
shift $((OPTIND-1))
orig_arg_package=${1-''}
package=${1:-$default_package}
# Check that we are actually on Linux, otherwise give a helpful reference.
[[ $(uname) == Linux ]] || Fatal "\
This script can be used on Linux only, and your system is $(uname).
Installer packages for Mac and Windows are available for download from Intel:
https://software.intel.com/mkl/choose-download"
# Test if MKL is already installed on the system.
if [[ ! $skip_cc ]]; then
: ${CC:=cc}
Have "$CC" || Fatal "\
C compiler $CC not found.
You can skip the check for MKL presence by invoking this script with the '-s'
option to this script, but you will need a functional compiler anyway, so we
recommend that you install it first."
mkl_version=$($CC -E -I /opt/intel/mkl/include - <<< \
'#include <mkl_version.h>
__INTEL_MKL__.__INTEL_MKL_MINOR__.__INTEL_MKL_UPDATE__' 2>/dev/null |
tail -n 1 ) || mkl_version=
mkl_version=${mkl_version// /}
[[ $mkl_version ]] && Fatal "\
MKL version $mkl_version is already installed.
You can skip the check for MKL presence by invoking this script with the '-s'
option and proceed with automated installation, but we highly discourage
this. This script will register Intel repositories with your system, and it
seems that they have been already registered, or MKL has been installed some
other way.
You should use your package manager to check which MKL package is already
installed. Note that Intel packages register the latest installed version of
the library as the default. If your installed version is older than
$package, it makes sense to upgrade."
fi
# Try to determine which package manager the distro uses, unless overridden.
if [[ ! $distro ]]; then
dist_vars=$(cat /etc/os-release 2>/dev/null)
eval "$dist_vars"
for rune in $CPE_NAME $ID $ID_LIKE; do
case "$rune" in
cpe:/o:fedoraproject:fedora:2[01]) distro=redhat; break;; # Use yum.
rhel|centos) distro=redhat; break;;
redhat|suse|fedora|debian|arch) distro=$rune; break;;
esac
done
# Certain old distributions do not have /etc/os-release. We are unlikely to
# encounter these in the wild, but just in case.
# NOTE: Do not try to guess Fedora specifically here! Fedora 20 and below
# detect as redhat, and this is good, because they use yum by default.
[[ ! $distro && -f /etc/redhat-release ]] && distro=redhat
[[ ! $distro && -f /etc/SuSE-release ]] && distro=suse
[[ ! $distro && -f /etc/debian_release ]] && distro=debian
[[ ! $distro && -f /etc/arch-release ]] && distro=arch
[[ ! $distro ]] && Fatal "\
Unable to determine package management style.
Invoke this script with the option '-p <style>', where <style> can be:
redhat -- RedHat-like, uses yum and rpm for package management.
fedora -- Fedora 22+, also RedHat-like, but uses dnf instead of yum.
suse -- SUSE-like, uses zypper and rpm.
debian -- Debian-like, uses apt and dpkg.
arch -- Archlinux, uses pacman.
We do not currently support other package management systems. Check the Intel's
documentation at https://software.intel.com/mkl/choose-download for other
install options."
echo >&2 "$0: Your system is using ${distro}-style package management."
fi
# Check for root.
if [[ "$(id -u)" -ne 0 ]]; then
echo >&2 "$0: You must be root to install MKL.
Restart this script using the 'sudo' command, as:
sudo $0 -sp $distro $package
We recommend adding the '-sp $distro' options to skip the MKL and distro
detection, since this has already been done. This minimizes the number of
programs invoked with the root privileges to keep your system safe from
unexpected or erroneous changes. Also, if you are setting the CC environment
variable, sudo might not allow it to propagate to the command that it invokes."
if [ -t 0 ]; then
echo; read -ep "Run the above sudo command now? [Y/n]:"
case $REPLY in
''|[Yy]*) set -x; exec sudo "$0" -sp "$distro" "$package"
esac
fi
exit 0
fi
# The install variants, each in a function to simplify error reporting.
# Each one invokes a subshell with a 'set -x' to show system-modifying
# commands it runs. The subshells simply limit the scope of this diagnostics
# and avoid creating noise (if we were using 'set +x', it would be printed).
Install_redhat () {
# yum-utils contains yum-config-manager, in case the user does not have it.
( set -x
rpm --import $intel_key_url
yum -y install yum-utils &&
yum-config-manager --add-repo "$yum_repo" &&
yum -y install "$package" )
}
Install_fedora () {
( set -x
rpm --import $intel_key_url
dnf -y install 'dnf-command(config-manager)' &&
dnf config-manager --add-repo "$yum_repo" &&
dnf -y install "$package" )
}
Install_suse () {
# zypper bug until libzypp-17.6.4: '--gpg-auto-import-keys' is ignored.
# See https://github.com/openSUSE/zypper/issues/144#issuecomment-418685933
# We must disable gpg checks with '--no-gpg-checks'. I won't bend backwards
# as far as check the installed .so version...
( set -x
rpm --import $intel_key_url
zypper addrepo "$yum_repo" &&
zypper --gpg-auto-import-keys --no-gpg-checks \
--non-interactive install "$package" )
}
Install_debian () {
local keyring='/usr/share/keyrings/intel-sw-products.gpg' \
sources_d='/etc/apt/sources.list.d' \
trusted_d='/etc/apt/trusted.gpg.d' \
apt_maj= apt_min= apt_ver=
# apt before 1.2 does not understand the signed-by option, and always
# look for the keyring in their trusted.gpg.d directory. This is not
# considered a good security practice any more. If apt is old, add a link
# to the keyring file and remind the user to delete it when apt is upgraded.
IFS=' .' builtin read _ apt_maj apt_min _ < <(apt-get --version)
apt_ver=$(builtin printf '%03d%03d' $apt_maj $apt_min)
# Get alternative location of /etc/apt/sources.list.d, if so configured.
eval $(apt-config shell sources_d Dir::Etc::sourceparts/f \
trusted_d Dir::Etc::trustedparts/f)
# apt is much more involved to configure than other package managers, as far
# as third-party security keys go.
( set -x;
apt-get update &&
apt-get install -y wget apt-transport-https ca-certificates gnupg &&
wget -qO- $intel_key_url | apt-key --keyring $keyring add - &&
echo "deb [signed-by=${keyring}] $apt_repo all main" \
> "$sources_d/intel-mkl.list" ) || return 1
if [[ $apt_ver < '001002' ]]; then
( set -x; ln -s "$keyring" "${trusted_d}/" ) || return 1
fi
( set +x
apt-get update &&
apt-get install -y "$package" ) || return 1
# Print the message after the large install, so the user may notice. I hope...
if [[ $apt_ver < '001002' ]]; then
echo >&2 "$0: Your apt-get version is earlier than 1.2.
This version does not understand individual repositories signing keys, and
trusts all keys in $trusted_d. We have created a link
$trusted_d/$(basename $keyring) pointing to the file
$keyring. If/when you upgrade your system to
a higher version of apt, removing this link will help make it more secure.
This is not considered a severe security issue, but separating keyrings is the
current recommended security practice."
fi
}
Install_arch () {
( set -x
echo y | pacman -Syu intel-mkl && # In pacman we don't specify the version
pacman -Q --info intel-mkl | grep -v None
)
}
# Register MKL .so libraries with the ld.so.
ConfigLdSo() {
[ -d /etc/ld.so.conf.d ] || return 0
type -t ldconfig >/dev/null || return 0
echo >&2 "$0: Configuring ld runtime bindings"
( set -x;
echo >/etc/ld.so.conf.d/intel-mkl.conf "\
/opt/intel/lib/intel64
/opt/intel/mkl/lib/intel64"
ldconfig )
}
# Invoke installation.
if Install_${distro} && ConfigLdSo; then
echo >&2 "$0: MKL package $package was successfully installed"
else
Fatal "MKL package $package installation FAILED.
Please open an issue with us at https://github.com/kaldi-asr/kaldi/ if you
believe this is a bug."
fi
#!/bin/bash
set -e
set -x
# need support c++17, so need gcc >= 8
# openfst
ngram=ngram-1.3.13
openfst=openfst-1.8.1
shared=true
export CPLUS_INCLUDE_PATH=$PWD/${openfst}/install/include/:$CPLUS_INCLUDE_PATH
export LD_LIBRARY_PATH=$PWD/${openfst}/install/lib/:$LD_LIBRARY_PATH
test -e ${ngram}.tar.gz || wget http://www.openfst.org/twiki/pub/GRM/NGramDownload/${ngram}.tar.gz
test -d ${ngram} || tar -xvf ${ngram}.tar.gz && chown -R root:root ${ngram}
if [ $shared == true ];then
pushd ${ngram} && ./configure --enable-shared && popd
else
pushd ${ngram} && ./configure --enable-static && popd
fi
pushd ${ngram} && make -j && make install && popd
#!/usr/bin/env bash
OPENBLAS_VERSION=0.3.13
WGET=${WGET:-wget}
set -e
if ! command -v gfortran 2>/dev/null; then
echo "$0: gfortran is not installed. Please install it, e.g. by:"
echo " apt-get install gfortran"
echo "(if on Debian or Ubuntu), or:"
echo " yum install gcc-gfortran"
echo "(if on RedHat/CentOS). On a Mac, if brew is installed, it's:"
echo " brew install gfortran"
exit 1
fi
tarball=OpenBLAS-$OPENBLAS_VERSION.tar.gz
rm -rf xianyi-OpenBLAS-* OpenBLAS OpenBLAS-*.tar.gz
if [ -d "$DOWNLOAD_DIR" ]; then
cp -p "$DOWNLOAD_DIR/$tarball" .
else
url=$($WGET -qO- "https://api.github.com/repos/xianyi/OpenBLAS/releases/tags/v${OPENBLAS_VERSION}" | python -c 'import sys,json;print(json.load(sys.stdin)["tarball_url"])')
test -n "$url"
$WGET -t3 -nv -O $tarball "$url"
fi
tar xzf $tarball
mv xianyi-OpenBLAS-* OpenBLAS
make PREFIX=$(pwd)/OpenBLAS/install USE_LOCKING=1 USE_THREAD=0 -C OpenBLAS all install
if [ $? -eq 0 ]; then
echo "OpenBLAS is installed successfully."
rm $tarball
fi
#!/bin/bash
set -e
set -x
# need support c++17, so need gcc >= 8
# openfst
openfst=openfst-1.8.1
shared=true
test -e ${openfst}.tar.gz || wget http://www.openfst.org/twiki/pub/FST/FstDownload/${openfst}.tar.gz
test -d ${openfst} || tar -xvf ${openfst}.tar.gz && chown -R root:root ${openfst}
if [ $shared == true ];then
pushd ${openfst} && ./configure --enable-shared --enable-compact-fsts --enable-compress --enable-const-fsts --enable-far --enable-linear-fsts --enable-lookahead-fsts --enable-mpdt --enable-ngram-fsts --enable-pdt --enable-python --enable-special --enable-bin --enable-grm --prefix ${PWD}/install && popd
else
pushd ${openfst} && ./configure --enable-static --enable-compact-fsts --enable-compress --enable-const-fsts --enable-far --enable-linear-fsts --enable-lookahead-fsts --enable-mpdt --enable-ngram-fsts --enable-pdt --enable-python --enable-special --enable-bin --enable-grm --prefix ${PWD}/install && popd
fi
pushd ${openfst} && make -j && make install && popd
suffix_path=$(python3 -c 'import sysconfig; import os; from pathlib import Path; site = sysconfig.get_paths()["purelib"]; site=Path(site); pwd = os.getcwd(); suffix = site.parts[-2:]; print(os.path.join(*suffix));')
wfst_so_path=${PWD}/${openfst}/install/lib/${suffix_path}
cp ${wfst_so_path}/pywrapfst.* $(python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')
#!/bin/bash
set -e
set -x
pynini=pynini-2.1.4
openfst=openfst-1.8.1
LIBRARY_PATH=$PWD/${openfst}/install/lib
test -e ${pynini}.tar.gz || wget http://www.openfst.org/twiki/pub/GRM/PyniniDownload/${pynini}.tar.gz
test -d ${pynini} || tar -xvf ${pynini}.tar.gz && chown -R root:root ${pynini}
pushd ${pynini} && LIBRARY_PATH=$LIBRARY_PATH python setup.py install && popd
#!/usr/bin/env bash
current_path=`pwd`
current_dir=`basename "$current_path"`
if [ "tools" != "$current_dir" ]; then
echo "You should run this script in tools/ directory!!"
exit 1
fi
if [ ! -d liblbfgs-1.10 ]; then
echo Installing libLBFGS library to support MaxEnt LMs
bash extras/install_liblbfgs.sh || exit 1
fi
# http://www.speech.sri.com/projects/srilm/download.html
if [ ! -f srilm.tgz ] && [ ! -f srilm.tar.gz ]; then # Changed format type from tgz to tar.gz as the srilm v1.7.3 downloads as tar.gz
echo This script cannot install SRILM in a completely automatic
echo way because you need to put your address in a download form.
echo Please download SRILM from http://www.speech.sri.com/projects/srilm/download.html
echo put it in ./srilm.tar.gz , then run this script.
echo Note: You may have to rename the downloaded file to remove version name from filename eg: mv srilm-1.7.3.tar.gz srilm.tar.gz
exit 1
fi
! which gawk 2>/dev/null && \
echo "GNU awk is not installed so SRILM will probably not work correctly: refusing to install" && exit 1;
mkdir -p srilm
cd srilm
if [ -f ../srilm.tgz ]; then
tar -xvzf ../srilm.tgz # Old SRILM format
elif [ -f ../srilm.tar.gz ]; then
tar -xvzf ../srilm.tar.gz # Changed format type from tgz to tar.gz
fi
major=`awk -F. '{ print $1 }' RELEASE`
minor=`awk -F. '{ print $2 }' RELEASE`
micro=`awk -F. '{ print $3 }' RELEASE`
if [ $major -le 1 ] && [ $minor -le 7 ] && [ $micro -le 1 ]; then
echo "Detected version 1.7.1 or earlier. Applying patch."
patch -p0 < ../extras/srilm.patch
fi
# set the SRILM variable in the top-level Makefile to this directory.
cp Makefile tmpf
cat tmpf | awk -v pwd=`pwd` '/SRILM =/{printf("SRILM = %s\n", pwd); next;} {print;}' \
> Makefile || exit 1
rm tmpf
mtype=`sbin/machine-type`
echo HAVE_LIBLBFGS=1 >> common/Makefile.machine.$mtype
grep ADDITIONAL_INCLUDES common/Makefile.machine.$mtype | \
sed 's|$| -I$(SRILM)/../liblbfgs-1.10/include|' \
>> common/Makefile.machine.$mtype
grep ADDITIONAL_LDFLAGS common/Makefile.machine.$mtype | \
sed 's|$| -L$(SRILM)/../liblbfgs-1.10/lib/ -Wl,-rpath -Wl,$(SRILM)/../liblbfgs-1.10/lib/|' \
>> common/Makefile.machine.$mtype
make || exit
cd ..
(
[ ! -z "${SRILM}" ] && \
echo >&2 "SRILM variable is aleady defined. Undefining..." && \
unset SRILM
[ -f ./env.sh ] && . ./env.sh
[ ! -z "${SRILM}" ] && \
echo >&2 "SRILM config is already in env.sh" && exit
wd=`pwd`
wd=`readlink -f $wd || pwd`
echo "export SRILM=$wd/srilm"
dirs="\${PATH}"
for directory in $(cd srilm && find bin -type d ) ; do
dirs="$dirs:\${SRILM}/$directory"
done
echo "export PATH=$dirs"
) >> env.sh
echo >&2 "Installation of SRILM finished successfully"
echo >&2 "Please source the tools/env.sh in your path.sh to enable it"
--- dstruct/src/Trie.orig 2016-11-08 19:53:40.524000000 +0000
+++ dstruct/src/Trie.cc 2016-11-08 19:53:59.088000000 +0000
@@ -200,11 +200,14 @@
if (removedData == 0) {
Trie<KeyT,DataT> node;
if (sub.remove(keys[0], &node)) {
+#if !defined(__GNUC__) || !(__GNUC__ >= 4 && __GNUC_MINOR__ >= 9 || __GNUC__ > 4)
/*
* XXX: Call subtrie destructor explicitly since we're not
* passing the removed node to the caller.
+ * !!! Triggers bug with gcc >= 4.9 optimization !!!
*/
node.~Trie();
+#endif
return true;
} else {
return false;
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -44,6 +44,11 @@ add_arg('manifest_paths', str,
"You can provide multiple manifest files.",
nargs='+',
required=True)
add_arg('text_keys', str,
'text',
"keys of the text in manifest for building vocabulary. "
"You can provide multiple k.",
nargs='+')
# bpe
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. Only needed when `unit_type` is spm")
@@ -58,10 +63,10 @@ def count_manifest(counter, text_feature, manifest_path):
line = text_feature.tokenize(line_json['text'])
counter.update(line)
def dump_text_manifest(fileobj, manifest_path):
def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
fileobj.write(line_json['text'] + "\n")
fileobj.write(line_json[key] + "\n")
def main():
print_arguments(args, globals())
@@ -78,7 +83,9 @@ def main():
fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
for manifest_path in args.manifest_paths:
dump_text_manifest(fp, manifest_path)
text_keys = [args.text_keys] if type(args.text_keys) is not list else args.text_keys
for text_key in text_keys:
dump_text_manifest(fp, manifest_path, key=text_key)
fp.close()
# train
spm.SentencePieceTrainer.Train(
......
#!/usr/bin/env python3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""format manifest into wav.scp text.word [text.syllable text.phone]"""
import argparse
from pathlib import Path
from typing import Union
from deepspeech.frontend.utility import read_manifest
key_whitelist = set(['feat', 'text', 'syllable', 'phone'])
filename = {
'text': 'text.word',
'syllable': 'text.syllable',
'phone': 'text.phone',
'feat': 'wav.scp',
}
def dump_manifest(manifest_path, output_dir: Union[str, Path]):
output_dir = Path(output_dir).expanduser()
manifest_path = Path(manifest_path).expanduser()
manifest_jsons = read_manifest(manifest_path)
first_line = manifest_jsons[0]
file_map = {}
for k in first_line.keys():
if k not in key_whitelist:
continue
file_map[k] = open(output_dir / filename[k], 'w')
for line_json in manifest_jsons:
for k in line_json.keys():
if k not in key_whitelist:
continue
file_map[k].write(line_json['utt'] + ' ' + line_json[k] + '\n')
for _, file in file_map.items():
file.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="dump manifest to wav.scp text.word ...")
parser.add_argument("--manifest-path", type=str, help="path to manifest")
parser.add_argument(
"--output-dir",
type=str,
help="path to save outputs(audio and transcriptions)")
args = parser.parse_args()
dump_manifest(args.manifest_path, args.output_dir)
#!/bin/bash
if [ $# != 1 ];then
echo "usage: ${0} manifest_file"
exit -1
fi
manifest=$1
jq -S '.feat_shape[0]' ${manifest} | sort -nu
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
do {
$shifted=0;
if ($ARGV[0] eq "--exclude") {
$exclude = 1;
shift @ARGV;
$shifted=1;
}
if ($ARGV[0] eq "-f") {
$field = $ARGV[1];
shift @ARGV; shift @ARGV;
$shifted=1
}
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
"only the lines that were *not* in id_list.\n" .
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
"-f option, add 1 to the argument.\n" .
"See also: utils/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
while(<>) {
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
# $1 is what we filter on.
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
print $_;
}
}
} else {
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
@A >= $field || die "Invalid scp file line $_";
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
print $_;
}
}
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
#!/usr/bin/env python3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""format manifest with more metadata."""
import argparse
import functools
import json
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.")
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath of the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
# bpe
add_arg('spm_model_prefix', str, None,
"spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm")
add_arg('output_path', str, None, "filepath of formatted manifest.", required=True)
# yapf: enable
args = parser.parse_args()
def main():
print_arguments(args, globals())
fout = open(args.output_path, 'w', encoding='utf-8')
# get feat dim
mean, std = load_cmvn(args.cmvn_path, filetype='json')
feat_dim = mean.shape[0] #(D)
print(f"Feature dim: {feat_dim}")
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
vocab_size = text_feature.vocab_size
print(f"Vocab size: {vocab_size}")
count = 0
for manifest_path in args.manifest_paths:
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
# text: translation text, text1: transcript text.
# Currently only a joint vocab is supported; a separate-vocab setting will be added.
line = line_json['text']
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
line_json['token'] = tokens
line_json['token_id'] = tokenids
line_json['token_shape'] = (len(tokenids), vocab_size)
line = line_json['text1']
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
line_json['token1'] = tokens
line_json['token_id1'] = tokenids
line_json['token_shape1'] = (len(tokenids), vocab_size)
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw':
feat_shape.append(feat_dim)
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
fout.write(json.dumps(line_json) + '\n')
count += 1
print(f"Examples number: {count}")
fout.close()
if __name__ == '__main__':
main()
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# 2013-2016 Johns Hopkins University (author: Daniel Povey)
# 2015 Hainan Xu
# 2015 Guoguo Chen
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Adds disambiguation symbols to a lexicon.
# Outputs still in the normal lexicon format.
# Disambig syms are numbered #1, #2, #3, etc. (#0
# reserved for symbol in grammar).
# Outputs the number of disambig syms to the standard output.
# With the --pron-probs option, expects the second field
# of each lexicon line to be a pron-prob.
# With the --sil-probs option, expects three additional
# fields after the pron-prob, representing various components
# of the silence probability model.
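# Illustrative example (not from the original script): if two entries share the
# same phone sequence,
#   red    r eh d
#   read   r eh d
# distinct disambiguation symbols are appended so the sequences differ:
#   red    r eh d #1
#   read   r eh d #2
# A pronunciation that is a prefix of another one (e.g. "ay" vs. "ay t ah m")
# is disambiguated in the same way.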
$pron_probs = 0;
$sil_probs = 0;
$first_allowed_disambig = 1;
for ($n = 1; $n <= 3 && @ARGV > 0; $n++) {
if ($ARGV[0] eq "--pron-probs") {
$pron_probs = 1;
shift @ARGV;
}
if ($ARGV[0] eq "--sil-probs") {
$sil_probs = 1;
shift @ARGV;
}
if ($ARGV[0] eq "--first-allowed-disambig") {
$first_allowed_disambig = 0 + $ARGV[1];
if ($first_allowed_disambig < 1) {
die "add_lex_disambig.pl: invalid --first-allowed-disambig option: $first_allowed_disambig\n";
}
shift @ARGV;
shift @ARGV;
}
}
if (@ARGV != 2) {
die "Usage: add_lex_disambig.pl [opts] <lexicon-in> <lexicon-out>\n" .
"This script adds disambiguation symbols to a lexicon in order to\n" .
"make decoding graphs determinizable; it adds pseudo-phone\n" .
"disambiguation symbols #1, #2 and so on at the ends of phones\n" .
"to ensure that all pronunciations are different, and that none\n" .
"is a prefix of another.\n" .
"It prints to the standard output the number of the largest-numbered" .
"disambiguation symbol that was used.\n" .
"\n" .
"Options: --pron-probs Expect pronunciation probabilities in the 2nd field\n" .
" --sil-probs [should be with --pron-probs option]\n" .
" Expect 3 extra fields after the pron-probs, for aspects of\n" .
" the silence probability model\n" .
" --first-allowed-disambig <n> The number of the first disambiguation symbol\n" .
" that this script is allowed to add. By default this is\n" .
" #1, but you can set this to a larger value using this option.\n" .
"e.g.:\n" .
" add_lex_disambig.pl lexicon.txt lexicon_disambig.txt\n" .
" add_lex_disambig.pl --pron-probs lexiconp.txt lexiconp_disambig.txt\n" .
" add_lex_disambig.pl --pron-probs --sil-probs lexiconp_silprob.txt lexiconp_silprob_disambig.txt\n";
}
$lexfn = shift @ARGV;
$lexoutfn = shift @ARGV;
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
# (1) Read in the lexicon.
@L = ( );
while(<L>) {
@A = split(" ", $_);
push @L, join(" ", @A);
}
# (2) Work out the count of each phone-sequence in the
# lexicon.
foreach $l (@L) {
@A = split(" ", $l);
shift @A; # Remove word.
if ($pron_probs) {
$p = shift @A;
if (!($p > 0.0 && $p <= 1.0)) { die "Bad lexicon line $l (expecting pron-prob as second field)"; }
}
if ($sil_probs) {
$silp = shift @A;
if (!($silp > 0.0 && $silp <= 1.0)) { die "Bad lexicon line $l for silprobs"; }
$correction = shift @A;
if ($correction <= 0.0) { die "Bad lexicon line $l for silprobs"; }
$correction = shift @A;
if ($correction <= 0.0) { die "Bad lexicon line $l for silprobs"; }
}
if (!(@A)) {
die "Bad lexicon line $1, no phone in phone list";
}
$count{join(" ",@A)}++;
}
# (3) For each left sub-sequence of each phone-sequence, note down
# that it exists (for identifying prefixes of longer strings).
foreach $l (@L) {
@A = split(" ", $l);
shift @A; # Remove word.
if ($pron_probs) { shift @A; } # remove pron-prob.
if ($sil_probs) {
shift @A; # Remove silprob
shift @A; # Remove silprob
}
while(@A > 0) {
pop @A; # Remove last phone
$issubseq{join(" ",@A)} = 1;
}
}
# (4) For each entry in the lexicon:
# if the phone sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same phone-seq
# has already been assigned a disambig symbol.
open(O, ">$lexoutfn") || die "Opening lexicon file $lexoutfn for writing.\n";
# max_disambig will always be the highest-numbered disambiguation symbol that
# has been used so far.
$max_disambig = $first_allowed_disambig - 1;
foreach $l (@L) {
@A = split(" ", $l);
$word = shift @A;
if ($pron_probs) {
$pron_prob = shift @A;
}
if ($sil_probs) {
$sil_word_prob = shift @A;
$word_sil_correction = shift @A;
$prev_nonsil_correction = shift @A
}
$phnseq = join(" ", @A);
if (!defined $issubseq{$phnseq}
&& $count{$phnseq} == 1) {
; # Do nothing.
} else {
if ($phnseq eq "") { # need disambig symbols for the empty string
# that are not used anywhere else.
$max_disambig++;
$reserved_for_the_empty_string{$max_disambig} = 1;
$phnseq = "#$max_disambig";
} else {
$cur_disambig = $last_used_disambig_symbol_of{$phnseq};
if (!defined $cur_disambig) {
$cur_disambig = $first_allowed_disambig;
} else {
$cur_disambig++; # Get a number that has not been used yet for
# this phone sequence.
}
while (defined $reserved_for_the_empty_string{$cur_disambig}) {
$cur_disambig++;
}
if ($cur_disambig > $max_disambig) {
$max_disambig = $cur_disambig;
}
$last_used_disambig_symbol_of{$phnseq} = $cur_disambig;
$phnseq = $phnseq . " #" . $cur_disambig;
}
}
if ($pron_probs) {
if ($sil_probs) {
print O "$word\t$pron_prob\t$sil_word_prob\t$word_sil_correction\t$prev_nonsil_correction\t$phnseq\n";
} else {
print O "$word\t$pron_prob\t$phnseq\n";
}
} else {
print O "$word\t$phnseq\n";
}
}
print $max_disambig . "\n";
#!/bin/bash
# Copyright 2015 Yajie Miao (Carnegie Mellon University)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script compiles the lexicon and CTC tokens into FSTs. FST compiling slightly differs between the
# phoneme and character-based lexicons.
set -eo pipefail
. utils/parse_options.sh
if [ $# -ne 3 ]; then
echo "usage: utils/fst/compile_lexicon_token_fst.sh <dict-src-dir> <tmp-dir> <lang-dir>"
echo "e.g.: utils/fst/compile_lexicon_token_fst.sh data/local/dict data/local/lang_tmp data/lang"
echo "<dict-src-dir> should contain the following files:"
echo "lexicon.txt lexicon_numbers.txt units.txt"
echo "options: "
exit 1;
fi
srcdir=$1
tmpdir=$2
dir=$3
mkdir -p $dir $tmpdir
[ -f path.sh ] && . ./path.sh
cp $srcdir/units.txt $dir
# Add probabilities to lexicon entries. There is in fact no point in doing this here since all the entries have 1.0.
# But utils/make_lexicon_fst.pl requires a probabilistic version, so we just leave it as it is.
perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $tmpdir/lexiconp.txt || exit 1;
# Add disambiguation symbols to the lexicon. This is necessary for determinizing the composition of L.fst and G.fst.
# Without these symbols, determinization will fail.
# default first disambiguation is #1
ndisambig=`utils/fst/add_lex_disambig.pl $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt`
# add #0 (#0 reserved for symbol in grammar).
ndisambig=$[$ndisambig+1];
( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) > $tmpdir/disambig.list
# Get the full list of CTC tokens used in FST. These tokens include <eps>, the blank <blk>,
# the actual model unit, and the disambiguation symbols.
cat $srcdir/units.txt | awk '{print $1}' > $tmpdir/units.list
(echo '<eps>';) | cat - $tmpdir/units.list $tmpdir/disambig.list | awk '{print $1 " " (NR-1)}' > $dir/tokens.txt
# ctc_token_fst_corrected is too big and too slow for character-based Chinese modeling,
# so here we just use the simple ctc_token_fst
utils/fst/ctc_token_fst.py --token_file $dir/tokens.txt | \
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/tokens.txt --keep_isymbols=false --keep_osymbols=false | \
fstarcsort --sort_type=olabel > $dir/T.fst || exit 1;
# Encode the words with indices. Will be used in lexicon and language model FST compiling.
cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | awk '
BEGIN {
print "<eps> 0";
}
{
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
}' > $dir/words.txt || exit 1;
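# For an illustrative two-word lexicon the words.txt produced above would be:
#   <eps> 0
#   hello 1
#   world 2
#   #0 3
#   <s> 4
#   </s> 5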
# Now compile the lexicon FST. Depending on the size of your lexicon, it may take some time.
token_disambig_symbol=`grep \#0 $dir/tokens.txt | awk '{print $2}'`
word_disambig_symbol=`grep \#0 $dir/words.txt | awk '{print $2}'`
utils/fst/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt 0 "sil" '#'$ndisambig | \
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/words.txt \
--keep_isymbols=false --keep_osymbols=false | \
fstaddselfloops "echo $token_disambig_symbol |" "echo $word_disambig_symbol |" | \
fstarcsort --sort_type=olabel > $dir/L.fst || exit 1;
echo "Lexicon and Token FSTs compiling succeeded"
#!/usr/bin/env python3
import argparse
def main(args):
"""Token Transducer"""
# <eps> entry
print('0 1 <eps> <eps>')
# skip beginning and ending <blank>
print('1 1 <blank> <eps>')
print('2 2 <blank> <eps>')
# <eps> exit
print('2 0 <eps> <eps>')
# linking `token` between node 1 and node 2
with open(args.token_file, 'r') as fin:
node = 3
for entry in fin:
fields = entry.strip().split(' ')
phone = fields[0]
if phone == '<eps>' or phone == '<blank>':
continue
elif '#' in phone:
# disambiguous phone
# `token` maybe ending with disambiguous symbol
print('{} {} {} {}'.format(0, 0, '<eps>', phone))
else:
# eating `token`
print('{} {} {} {}'.format(1, node, phone, phone))
# remove repeating `token`
print('{} {} {} {}'.format(node, node, phone, '<eps>'))
# leaving `token`
print('{} {} {} {}'.format(node, 2, '<eps>', '<eps>'))
node += 1
# Final node
print('0')
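# Illustrative output (not from the repo) for a tokens.txt that holds only
# <eps>, <blank>, "a" and "#0"; columns are src-state dst-state ilabel olabel,
# and the trailing "0" marks state 0 as final:
#   0 1 <eps> <eps>
#   1 1 <blank> <eps>
#   2 2 <blank> <eps>
#   2 0 <eps> <eps>
#   1 3 a a
#   3 3 a <eps>
#   3 2 <eps> <eps>
#   0 0 <eps> #0
#   0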
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='FST: CTC Token FST transducer')
parser.add_argument(
'--token_file',
required=True,
help='e2e model token file. line: token(char/phone/spm/disambiguous)')
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
import argparse
def il(n):
"""ilabel"""
return n + 1
def ol(n):
"""olabel"""
return n + 1
def s(n):
"""state"""
return n
def main(args):
with open(args.token_file) as f:
lines = f.readlines()
# token count w/o <blank> and <eps>
phone_count = 0
disambig_count = 0
for line in lines:
sp = line.strip().split()
phone = sp[0]
if phone == '<eps>' or phone == '<blank>':
continue
if phone.startswith('#'):
disambig_count += 1
else:
phone_count += 1
# 1. add start state
# first token is <blank>:0
print('0 0 {} 0'.format(il(0)))
# 2. 0 -> i, i -> i, i -> 0
# non-blank token start from 1
for i in range(1, phone_count + 1):
# eating `token`
print('0 {} {} {}'.format(s(i), il(i), ol(i)))
# remove repeating `token`
print('{} {} {} 0'.format(s(i), s(i), il(i)))
# skip ending <blank> `token`
print('{} 0 {} 0'.format(s(i), il(0)))
# 3. i -> other phone
# non-blank token to other non-blank token
for i in range(1, phone_count + 1):
for j in range(1, phone_count + 1):
if i != j:
print('{} {} {} {}'.format(s(i), s(j), il(j), ol(j)))
# 4. add disambiguous arcs on every final state
# blank and non-blank tokens may end with a disambiguation `token`
for i in range(0, phone_count + 1):
for j in range(phone_count + 2, phone_count + disambig_count + 2):
print('{} {} {} {}'.format(s(i), s(i), 0, j))
# 5. every i is final state
# blank and non-blank `token` are final state
for i in range(0, phone_count + 1):
print(s(i))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='FST: CTC Token unfold FST transducer')
parser.add_argument(
'--token_file',
required=True,
help='e2e model token file. line: token(char/phone/spm/disambiguous)')
args = parser.parse_args()
main(args)
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# 2015 Guoguo Chen
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script replaces epsilon with #0 on the input side only, of the G.fst
# acceptor.
while(<>){
if (/\s+#0\s+/) {
print STDERR "$0: ERROR: LM has word #0, " .
"which is reserved as disambiguation symbol\n";
exit 1;
}
s:^(\d+\s+\d+\s+)\<eps\>(\s+):$1#0$2:;
print;
}
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2010-2011 Microsoft Corporation
# 2013 Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# makes lexicon FST, in text form, from lexicon (pronunciation probabilities optional).
$pron_probs = 0;
if ((@ARGV > 0) && ($ARGV[0] eq "--pron-probs")) {
$pron_probs = 1;
shift @ARGV;
}
if (@ARGV != 1 && @ARGV != 3 && @ARGV != 4) {
print STDERR "Usage: make_lexicon_fst.pl [--pron-probs] lexicon.txt [silprob silphone [sil_disambig_sym]] >lexiconfst.txt\n\n";
print STDERR "Creates a lexicon FST that transduces phones to words, and may allow optional silence.\n\n";
print STDERR "Note: ordinarily, each line of lexicon.txt is:\n";
print STDERR " word phone1 phone2 ... phoneN;\n";
print STDERR "if the --pron-probs option is used, each line is:\n";
print STDERR " word pronunciation-probability phone1 phone2 ... phoneN.\n\n";
print STDERR "The probability 'prob' will typically be between zero and one, and note that\n";
print STDERR "it's generally helpful to normalize so the largest one for each word is 1.0, but\n";
print STDERR "this is your responsibility.\n\n";
print STDERR "The silence disambiguation symbol, e.g. something like #5, is used only\n";
print STDERR "when creating a lexicon with disambiguation symbols, e.g. L_disambig.fst,\n";
print STDERR "and was introduced to fix a particular case of non-determinism of decoding graphs.\n\n";
exit(1);
}
$lexfn = shift @ARGV;
if (@ARGV == 0) {
$silprob = 0.0;
} elsif (@ARGV == 2) {
($silprob,$silphone) = @ARGV;
} else {
($silprob,$silphone,$sildisambig) = @ARGV;
}
if ($silprob != 0.0) {
$silprob < 1.0 || die "Sil prob cannot be >= 1.0";
$silcost = -log($silprob);
$nosilcost = -log(1.0 - $silprob);
}
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
if ( $silprob == 0.0 ) { # No optional silences: just have one (loop+final) state which is numbered zero.
$loopstate = 0;
$nextstate = 1; # next unallocated state.
while (<L>) {
@A = split(" ", $_);
@A == 0 && die "Empty lexicon line.";
foreach $a (@A) {
if ($a eq "<eps>") {
die "Bad lexicon line $_ (<eps> is forbidden)";
}
}
$w = shift @A;
if (! $pron_probs) {
$pron_cost = 0.0;
} else {
$pron_prob = shift @A;
if (! defined $pron_prob || !($pron_prob > 0.0 && $pron_prob <= 1.0)) {
die "Bad pronunciation probability in line $_";
}
$pron_cost = -log($pron_prob);
}
if ($pron_cost != 0.0) { $pron_cost_string = "\t$pron_cost"; } else { $pron_cost_string = ""; }
$s = $loopstate;
$word_or_eps = $w;
while (@A > 0) {
$p = shift @A;
if (@A > 0) {
$ns = $nextstate++;
} else {
$ns = $loopstate;
}
print "$s\t$ns\t$p\t$word_or_eps$pron_cost_string\n";
$word_or_eps = "<eps>";
$pron_cost_string = ""; # so we only print it on the first arc of the word.
$s = $ns;
}
}
print "$loopstate\t0\n"; # final-cost.
} else { # have silence probs.
$startstate = 0;
$loopstate = 1;
$silstate = 2; # state from where we go to loopstate after emitting silence.
print "$startstate\t$loopstate\t<eps>\t<eps>\t$nosilcost\n"; # no silence.
if (!defined $sildisambig) {
print "$startstate\t$loopstate\t$silphone\t<eps>\t$silcost\n"; # silence.
print "$silstate\t$loopstate\t$silphone\t<eps>\n"; # no cost.
$nextstate = 3;
} else {
$disambigstate = 3;
$nextstate = 4;
print "$startstate\t$disambigstate\t$silphone\t<eps>\t$silcost\n"; # silence.
print "$silstate\t$disambigstate\t$silphone\t<eps>\n"; # no cost.
print "$disambigstate\t$loopstate\t$sildisambig\t<eps>\n"; # silence disambiguation symbol.
}
while (<L>) {
@A = split(" ", $_);
$w = shift @A;
if (! $pron_probs) {
$pron_cost = 0.0;
} else {
$pron_prob = shift @A;
if (! defined $pron_prob || !($pron_prob > 0.0 && $pron_prob <= 1.0)) {
die "Bad pronunciation probability in line $_";
}
$pron_cost = -log($pron_prob);
}
if ($pron_cost != 0.0) { $pron_cost_string = "\t$pron_cost"; } else { $pron_cost_string = ""; }
$s = $loopstate;
$word_or_eps = $w;
while (@A > 0) {
$p = shift @A;
if (@A > 0) {
$ns = $nextstate++;
print "$s\t$ns\t$p\t$word_or_eps$pron_cost_string\n";
$word_or_eps = "<eps>";
$pron_cost_string = ""; $pron_cost = 0.0; # so we only print it the 1st time.
$s = $ns;
} elsif (!defined($silphone) || $p ne $silphone) {
# This is non-deterministic but relatively compact,
# and avoids epsilons.
$local_nosilcost = $nosilcost + $pron_cost;
$local_silcost = $silcost + $pron_cost;
print "$s\t$loopstate\t$p\t$word_or_eps\t$local_nosilcost\n";
print "$s\t$silstate\t$p\t$word_or_eps\t$local_silcost\n";
} else {
# no point putting opt-sil after silence word.
print "$s\t$loopstate\t$p\t$word_or_eps$pron_cost_string\n";
}
}
}
print "$loopstate\t0\n"; # final-cost.
}
#!/bin/bash
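# Usage (inferred from the positional arguments below):
#   $0 <lm_dir> <src_lang> <tgt_lang>
# Copies <src_lang> to <tgt_lang>, builds G.fst from ${lm_dir}/lm.arpa, and
# composes T.fst, L.fst and G.fst into <tgt_lang>/TLG.fst.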
if [ -f path.sh ]; then . path.sh; fi
lm_dir=$1
src_lang=$2
tgt_lang=$3
arpa_lm=${lm_dir}/lm.arpa
[ ! -f $arpa_lm ] && { echo "No such file $arpa_lm"; exit 1;}
rm -rf $tgt_lang
cp -r $src_lang $tgt_lang
# Compose the language model to FST
# grep -i / --ignore-case: ignore case distinctions.
# grep -v / --invert-match: select non-matching lines.
# arpa2fst: remove the embedded symbols from the FST
# arpa2fst: make sure there are no out-of-vocabulary words in the language model
# arpa2fst: remove "illegal" sequences of the start and end-ofsentence symbols
# eps2disambig.pl: replace epsilons on the input side with the special disambiguation symbol #0.
# s2eps.pl: replaces <s> and </s> with <eps> (on both input and output sides), for the G.fst acceptor.
# G.fst, the disambiguation symbol #0 only appears on the input side
# eps2disambig.pl and s2eps.pl may be needed only for the following `fstrmepsilon`.
cat $arpa_lm | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
grep -v -i '<unk>' | \
grep -v -i '<spoken_noise>' | \
arpa2fst --read-symbol-table=$tgt_lang/words.txt --keep-symbols=true - | fstprint | \
utils/fst/eps2disambig.pl | utils/fst/s2eps.pl | fstcompile --isymbols=$tgt_lang/words.txt \
--osymbols=$tgt_lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G.fst
echo "Checking how stochastic G is (the first of these numbers should be small):"
fstisstochastic $tgt_lang/G.fst
# Compose the token, lexicon and language-model FST into the final decoding graph
# minimization: the same as minimization algorithm that applies to weighted acceptors;
# the only change relevant here is that it avoids pushing weights,
# hence preserving stochasticity
fsttablecompose $tgt_lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \
fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1;
fsttablecompose $tgt_lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1;
echo "Composing decoding graph TLG.fst succeeded"
#rm -r $tgt_lang/LG.fst # We don't need to keep this intermediate FST
#!/usr/bin/env python3
import argparse
def main(args):
# load `unit` or `vocab` file
unit_table = set()
with open(args.unit_file, 'r') as fin:
for line in fin:
unit = line.strip()
unit_table.add(unit)
def contain_oov(units):
for unit in units:
if unit not in unit_table:
return True
return False
# load spm model
bpemode = args.bpemodel
if bpemode:
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.Load(args.bpemodel)
# used to filter polyphone
lexicon_table = set()
with open(args.in_lexicon, 'r') as fin, \
open(args.out_lexicon, 'w') as fout:
for line in fin:
word = line.split()[0]
if word == 'SIL' and not bpemode: # `sil` might be a valid piece in bpemodel
continue
elif word == '<SPOKEN_NOISE>':
continue
else:
# each word only has one pronunciation for e2e system
if word in lexicon_table:
continue
if bpemode:
pieces = sp.EncodeAsPieces(word)
if contain_oov(pieces):
print('Ignoring words {}, which contains oov unit'.
format(''.join(word).strip('▁')))
continue
chars = ' '.join(
[p if p in unit_table else '<unk>' for p in pieces])
else:
# ignore words with OOV
if contain_oov(word):
print('Ignoring words {}, which contains oov unit'.
format(word))
continue
# Optional, append ▁ in front of english word
# we assume the model unit of our e2e system is char now.
if word.encode('utf8').isalpha() and '▁' in unit_table:
word = '▁' + word
chars = ' '.join(word) # word is a char list
fout.write('{} {}\n'.format(word, chars))
lexicon_table.add(word)
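# Illustrative behaviour (assuming every character below, and ▁, is in unit_file):
#   "HELLO hh ah l ow"  ->  "▁HELLO ▁ H E L L O"
#   "你好 n i3 h ao3"    ->  "你好 你 好"
# With --bpemodel the output side is sp.EncodeAsPieces(word) instead of the
# per-character split.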
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='FST: prepare e2e (char/spm) dict')
parser.add_argument(
'--unit_file',
required=True,
help='e2e model unit file (lang_char.txt/vocab.txt). line: char/spm_pieces'
)
parser.add_argument(
'--in_lexicon',
required=True,
help='raw lexicon file. line: word ph0 ... phn')
parser.add_argument(
'--out_lexicon',
required=True,
help='output lexicon file. line: word char0 ... charn')
parser.add_argument('--bpemodel', default=None, help='bpemodel')
args = parser.parse_args()
print(args)
main(args)
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script removes lines that contain these OOVs on either the
# third or fourth fields of the line. It is intended to remove arcs
# with OOVs on, from FSTs (probably compiled from ARPAs with OOVs in).
if (@ARGV < 1 || @ARGV > 2) {
die "Usage: remove_oovs.pl unk_list.txt [ printed-fst ]\n";
}
$unklist = shift @ARGV;
open(S, "<$unklist") || die "Failed opening unknown-symbol list $unklist\n";
while(<S>){
@A = split(" ", $_);
@A == 1 || die "Bad line in unknown-symbol list: $_";
$unk{$A[0]} = 1;
}
$num_removed = 0;
while(<>){
@A = split(" ", $_);
if(defined $unk{$A[2]} || defined $unk{$A[3]}) {
$num_removed++;
} else {
print;
}
}
print STDERR "remove_oovs.pl: removed $num_removed lines.\n";
#!/usr/bin/env python3
import argparse
def main(args):
# skip <blank> `token`
print('0 0 <blank> <eps>')
with open(args.token_file, 'r') as fin:
for entry in fin:
fields = entry.strip().split(' ')
phone = fields[0]
if phone == '<eps>' or phone == '<blank>':
continue
elif '#' in phone:
# disambiguous phone
# maybe add disambiguous `token`
print('{} {} {} {}'.format(0, 0, '<eps>', phone))
else:
# eating `token`
print('{} {} {} {}'.format(0, 0, phone, phone))
# final state
print('0')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='FST: RNN-T Token FST transducer')
parser.add_argument(
'--token_file',
required=True,
help='e2e model token file. line: token(char/phone/spm/disambiguous)')
args = parser.parse_args()
main(args)
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script replaces <s> and </s> with <eps> (on both input and output sides),
# for the G.fst acceptor.
while(<>){
@A = split(" ", $_);
if ( @A >= 4 ) {
if ($A[2] eq "<s>" || $A[2] eq "</s>") { $A[2] = "<eps>"; }
if ($A[3] eq "<s>" || $A[3] eq "</s>") { $A[3] = "<eps>"; }
}
print join("\t", @A) . "\n";
}
#!/usr/bin/env python3
"""Manifest file to key-value files."""
import argparse
import functools
from pathlib import Path
from utils.utility import add_arguments
from utils.utility import print_arguments
from utils.utility import read_manifest
def main(args):
print_arguments(args, globals())
count = 0
outdir = Path(args.output_path)
wav_scp = outdir / 'wav.scp'
dur_scp = outdir / 'duration'
text_scp = outdir / 'text'
manifest_jsons = read_manifest(args.manifest_path)
with wav_scp.open('w') as fwav, dur_scp.open('w') as fdur, text_scp.open(
'w') as ftxt:
for line_json in manifest_jsons:
utt = line_json['utt']
feat = line_json['feat']
file_ext = Path(feat).suffix # .wav
text = line_json['text']
feat_shape = line_json['feat_shape']
dur = feat_shape[0]
feat_dim = feat_shape[1]
if 'token' in line_json:
tokens = line_json['token']
tokenids = line_json['token_id']
token_shape = line_json['token_shape']
token_len = token_shape[0]
vocab_dim = token_shape[1]
if file_ext == '.wav':
fwav.write(f"{utt} {feat}\n")
fdur.write(f"{utt} {dur}\n")
ftxt.write(f"{utt} {text}\n")
count += 1
print(f"Examples number: {count}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('manifest_path', str,
'data/librispeech/manifest.train',
"Filepath of manifest to compute normalizer's mean and stddev.")
add_arg('output_path', str,
'data/train',
"dir path to dump wav.scp/duaration/text files.")
# yapf: enable
args = parser.parse_args()
main(args)
......@@ -22,7 +22,7 @@ lmbin=${2}.klm.bin
# https://kheafield.com/code/kenlm/estimation/
echo "build arpa lm."
lmplz -o ${order} -S ${mem} --prune ${prune} < ${text} > ${arpa} || { echo "train kenlm error!"; exit -1; }
# https://kheafield.com/code/kenlm/
echo "build binary lm."
......
......@@ -18,7 +18,13 @@ function clean() {
}
trap clean EXIT
# ckpt_prefix dir
if [ -d ${ckpt_prefix} ];then
cp -r ${ckpt_prefix} ${output}
fi
# ckpt_prefix.{json,...}
cp ${ckpt_prefix}.* ${output}
# model config, mean std, vocab
cp ${model_config} ${mean_std} ${vocab} ${output}
tar zcvf release.tar.gz ${output}
......
......@@ -11,11 +11,93 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import distutils.util
import hashlib
import json
import os
import tarfile
import zipfile
from typing import Text
from paddle.dataset.common import md5file
__all__ = [
"check_md5sum", "getfile_insensitive", "download_multi", "download",
"unpack", "unzip", "md5file", "print_arguments", "add_arguments",
"read_manifest"
]
def read_manifest(manifest_path):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
for json_line in open(manifest_path, 'r'):
try:
json_data = json.loads(json_line)
except Exception as e:
raise IOError("Error reading manifest: %s" % str(e))
return manifest
def print_arguments(args, info=None):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
filename = ""
if info:
filename = info["__file__"]
filename = os.path.basename(filename)
print(f"----------- {filename} Configuration Arguments -----------")
for arg, value in sorted(vars(args).items()):
print("%s: %s" % (arg, value))
print("-----------------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def md5file(fname):
hash_md5 = hashlib.md5()
f = open(fname, "rb")
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
f.close()
return hash_md5.hexdigest()
def getfile_insensitive(path):
......@@ -54,6 +136,19 @@ def download(url, md5sum, target_dir):
return filepath
def check_md5sum(filepath: Text, md5sum: Text) -> bool:
"""check md5sum of file.
Args:
filepath (Text): path of the file to check.
md5sum (Text): expected md5 checksum (hex digest).
Returns:
bool: same or not.
"""
return md5file(filepath) == md5sum
def unpack(filepath, target_dir, rm_tar=False):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
......