提交 617605c8 编写于 作者: C chenfeiyu

place parakeet into Parakeet/parakeet, and add tests

上级 5ac19fa5
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
\ No newline at end of file
{
"python.pythonPath": "/Users/chenfeiyu/miniconda3/envs/paddle/bin/python"
}
\ No newline at end of file
......@@ -3,7 +3,17 @@ functions to make batch for arrays which satisfy some conditions.
"""
import numpy as np
def text_collate(minibatch):
class TextIDBatcher(object):
"""A wrapper class for a function to build a functor, which holds the configs to pass to the function."""
def __init__(self, pad_id=0, dtype=np.int64):
self.pad_id = pad_id
self.dtype = dtype
def __call__(self, minibatch):
out = batch_text_id(minibatch, pad_id=self.pad_id, dtype=self.dtype)
return out
def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
"""
minibatch: List[Example]
Example: ndarray, shape(T,), dtype: int64
......@@ -17,15 +27,25 @@ def text_collate(minibatch):
batch = []
for example in minibatch:
pad_len = max_len - example.shape[0]
batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=0))
batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=pad_id))
return np.array(batch, dtype=dtype)
class WavBatcher(object):
def __init__(self, pad_value=0., dtype=np.float32):
self.pad_value = pad_value
self.dtype = dtype
return np.array(batch, dtype=np.int64)
def __call__(self, minibatch):
out = batch_wav(minibatch, pad_value=self.pad_value, dtype=self.dtype)
return out
def wav_collate(minibatch):
def batch_wav(minibatch, pad_value=0., dtype=np.float32):
"""
minibatch: List[Example]
Example: ndarray, shape(C, T) for multi-channel wav, shape(T,) for mono-channel wav, dtype: float32
"""
# detect data format, maybe better to specify it in __init__
peek_example = minibatch[0]
if len(peek_example.shape) == 1:
mono_channel = True
......@@ -39,13 +59,23 @@ def wav_collate(minibatch):
for example in minibatch:
pad_len = max_len - example.shape[-1]
if mono_channel:
batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=0.))
batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=pad_value))
else:
batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=0.)) # what about PCM, no
batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=pad_value)) # what about PCM, no
return np.array(batch, dtype=dtype)
class SpecBatcher(object):
def __init__(self, pad_value=0., dtype=np.float32):
self.pad_value = pad_value
self.dtype = dtype
return np.array(batch, dtype=np.float32)
def __call__(self, minibatch):
out = batch_spec(minibatch, pad_value=self.pad_value, dtype=self.dtype)
return out
def spec_collate(minibatch):
def batch_spec(minibatch, pad_value=0., dtype=np.float32):
"""
minibatch: List[Example]
Example: ndarray, shape(C, F, T) for multi-channel spectrogram, shape(F, T) for mono-channel spectrogram, dtype: float32
......@@ -64,8 +94,8 @@ def spec_collate(minibatch):
for example in minibatch:
pad_len = max_len - example.shape[-1]
if mono_channel:
batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=0.))
batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=pad_value))
else:
batch.append(np.pad(example, [(0, 0), (0, 0), (0, pad_len)], mode='constant', constant_values=0.)) # what about PCM, no
batch.append(np.pad(example, [(0, 0), (0, 0), (0, pad_len)], mode='constant', constant_values=pad_value)) # what about PCM, no
return np.array(batch, dtype=np.float32)
\ No newline at end of file
return np.array(batch, dtype=dtype)
\ No newline at end of file
from sampler import SequentialSampler, RandomSampler, BatchSampler
from .sampler import SequentialSampler, RandomSampler, BatchSampler
class DataLoader(object):
def __init__(self, dataset, batch_size=1, collate_fn = lambda x: x,
sampler=None, shuffle=False, batch_sampler=None, drop_last=False):
class DataCargo(object):
def __init__(self, dataset, batch_size=1, sampler=None,
shuffle=False, batch_sampler=None, drop_last=False):
self.dataset = dataset
self.collate_fn = collate_fn
if batch_sampler is not None:
# auto_collation with custom batch_sampler
......@@ -14,20 +13,14 @@ class DataLoader(object):
'drop_last')
batch_size = None
drop_last = False
shuffle = False
elif batch_size is None:
# no auto_collation
if shuffle or drop_last:
raise ValueError('batch_size=None option disables auto-batching '
'and is mutually exclusive with '
'shuffle, and drop_last')
if sampler is None: # give default samplers
raise ValueError('batch sampler is none. then batch size must not be none.')
elif sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
......@@ -73,7 +66,7 @@ class DataIterator(object):
def __next__(self):
index = self._next_index() # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size
minibatch = [self._dataset[i] for i in index] # we can abstract it, too to use dynamic batch size
minibatch = self.loader.collate_fn(minibatch) # list[Example] -> Batch
minibatch = self._dataset._batch_examples(minibatch) # list[Example] -> Batch
return minibatch
def _next_index(self):
......
class Dataset(object):
def __init__(self, lazy=True, stream=False):
# note that lazy and stream means two different things in our glossary
# lazy means to place preprocessing in __getitem__
# stram means the data source is itself a stream
self.lazy = lazy
self.stream = stream
def __init__(self):
pass
def _load_metadata(self):
raise NotImplementedError
def _get_example(self):
"""return a Record"""
"""return a Record (or Example, Instance according to your glossary)"""
raise NotImplementedError
def _batch_examples(self, minibatch):
"""get a list of examples, return a batch, whose structure is the same as an example"""
raise NotImplementedError
def _prepare_metadata(self):
......
......@@ -2,24 +2,20 @@ from pathlib import Path
import numpy as np
import pandas as pd
import librosa
import g2p
from .. import g2p
from sampler import SequentialSampler, RandomSampler, BatchSampler
from dataset import Dataset
from dataloader import DataLoader
from .sampler import SequentialSampler, RandomSampler, BatchSampler
from .dataset import Dataset
from .datacargo import DataCargo
from .batch import TextIDBatcher, SpecBatcher
from collate import text_collate, spec_collate
LJSPEECH_ROOT = Path("/Users/chenfeiyu/projects/LJSpeech-1.1")
class LJSpeech(Dataset):
def __init__(self, root=LJSPEECH_ROOT, lazy=True, stream=False):
super(LJSpeech, self).__init__(lazy, stream)
def __init__(self, root):
super(LJSpeech, self).__init__()
self.root = root
self.metadata = self._prepare_metadata() # we can do this just for luck
if self.stream:
self.examples_generator = self._read()
def _prepare_metadata(self):
# if pure-stream case, each _prepare_metadata returns a generator
csv_path = self.root.joinpath("metadata.csv")
......@@ -27,11 +23,6 @@ class LJSpeech(Dataset):
names=["fname", "raw_text", "normalized_text"])
return metadata
def _read(self):
for _, metadatum in self.metadata.iterrows():
example = self._get_example(metadatum)
yield example
def _get_example(self, metadatum):
"""All the code for generating an Example from a metadatum. If you want a
different preprocessing pipeline, you can override this method.
......@@ -62,44 +53,30 @@ class LJSpeech(Dataset):
phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
def _batch_examples(self, minibatch):
mag_batch = []
mel_batch = []
phoneme_batch = []
for example in minibatch:
mag, mel, phoneme = example
mag_batch.append(mag)
mel_batch.append(mel)
phoneme_batch.append(phoneme)
mag_batch = SpecBatcher(pad_value=0.)(mag_batch)
mel_batch = SpecBatcher(pad_value=0.)(mel_batch)
phoneme_batch = TextIDBatcher(pad_id=0)(phoneme_batch)
return (mag_batch, mel_batch, phoneme_batch)
def __getitem__(self, index):
if self.stream:
raise ValueError("__getitem__ is invalid in stream mode")
metadatum = self.metadata.iloc[index]
example = self._get_example(metadatum)
return example
def __iter__(self):
if self.stream:
for example in self.examples_generator:
yield example
else:
for i in range(len(self)):
yield self[i]
def __len__(self):
if self.stream:
raise ValueError("__len__ is invalid in stream mode")
return len(self.metadata)
def fn(minibatch):
mag_batch = []
mel_batch = []
phoneme_batch = []
for example in minibatch:
mag, mel, phoneme = example
mag_batch.append(mag)
mel_batch.append(mel)
phoneme_batch.append(phoneme)
mag_batch = spec_collate(mag_batch)
mel_batch = spec_collate(mel_batch)
phoneme_batch = text_collate(phoneme_batch)
return (mag_batch, mel_batch, phoneme_batch)
if __name__ == "__main__":
ljspeech = LJSpeech(LJSPEECH_ROOT)
ljspeech_loader = DataLoader(ljspeech, batch_size=16, shuffle=True, collate_fn=fn)
for i, batch in enumerate(ljspeech_loader):
print(i)
......@@ -12,22 +12,22 @@ and the property:
- n_vocab
"""
from g2p import en
from . import en
# optinoal Japanese frontend
try:
from g2p import jp
from . import jp
except ImportError:
jp = None
try:
from g2p import ko
from . import ko
except ImportError:
ko = None
# if you are going to use the frontend, you need to modify _characters in symbol.py:
# _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' + '¡¿ñáéíóúÁÉÍÓÚÑ'
try:
from g2p import es
from . import es
except ImportError:
es = None
# coding: utf-8
from g2p.text.symbols import symbols
from g2p import text
from g2p.text import sequence_to_text
from ..text.symbols import symbols
from ..text import sequence_to_text
import nltk
from random import random
......@@ -30,7 +29,7 @@ def mix_pronunciation(text, p):
def text_to_sequence(text, p=0.0):
if p >= 0:
text = mix_pronunciation(text, p)
from g2p.text import text_to_sequence
from ..text import text_to_sequence
text = text_to_sequence(text, ["english_cleaners"])
return text
......
# coding: utf-8
from g2p.text.symbols import symbols
from g2p.text import sequence_to_text
from ..text.symbols import symbols
from ..text import sequence_to_text
import nltk
from random import random
......@@ -9,7 +9,7 @@ n_vocab = len(symbols)
def text_to_sequence(text, p=0.0):
from g2p.text import text_to_sequence
from ..text import text_to_sequence
text = text_to_sequence(text, ["basic_cleaners"])
return text
......
import re
from g2p.text import cleaners
from g2p.text.symbols import symbols
from . import cleaners
from .symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
......
from parakeet.data.ljspeech import LJSpeech
from parakeet.data.datacargo import DataCargo
from pathlib import Path
LJSPEECH_ROOT = Path("/Users/chenfeiyu/projects/LJSpeech-1.1")
ljspeech = LJSpeech(LJSPEECH_ROOT)
ljspeech_cargo = DataCargo(ljspeech, batch_size=16, shuffle=True)
for i, batch in enumerate(ljspeech_cargo):
print(i)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册