未验证 提交 3e199781 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1054 from zh794390558/visual

[asr] using visualdl , jsonlines read manifest
...@@ -339,6 +339,3 @@ You need to prepare an audio file, please confirm the sample rate of the audio i ...@@ -339,6 +339,3 @@ You need to prepare an audio file, please confirm the sample rate of the audio i
```bash ```bash
CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav
``` ```
...@@ -128,8 +128,9 @@ class U2Trainer(Trainer): ...@@ -128,8 +128,9 @@ class U2Trainer(Trainer):
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
self.visualizer.add_scalars("step", losses_np_v, for key, val in losses_np_v.items():
self.iteration - 1) self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
...@@ -237,9 +238,10 @@ class U2Trainer(Trainer): ...@@ -237,9 +238,10 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(
'epoch', {'cv_loss': cv_loss, tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch) self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()
......
...@@ -131,8 +131,9 @@ class U2Trainer(Trainer): ...@@ -131,8 +131,9 @@ class U2Trainer(Trainer):
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
self.visualizer.add_scalars("step", losses_np_v, for key, val in losses_np_v.items():
self.iteration - 1) self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
...@@ -222,9 +223,11 @@ class U2Trainer(Trainer): ...@@ -222,9 +223,11 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(
'epoch', {'cv_loss': cv_loss, tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch) self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()
......
...@@ -138,8 +138,9 @@ class U2STTrainer(Trainer): ...@@ -138,8 +138,9 @@ class U2STTrainer(Trainer):
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
self.visualizer.add_scalars("step", losses_np_v, for key, val in losses_np_v.items():
self.iteration - 1) self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
...@@ -235,9 +236,11 @@ class U2STTrainer(Trainer): ...@@ -235,9 +236,11 @@ class U2STTrainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(
'epoch', {'cv_loss': cv_loss, tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch) self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()
......
...@@ -12,9 +12,10 @@ ...@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains the impulse response augmentation model.""" """Contains the impulse response augmentation model."""
import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase
from paddlespeech.s2t.frontend.utility import read_manifest
class ImpulseResponseAugmentor(AugmentorBase): class ImpulseResponseAugmentor(AugmentorBase):
...@@ -28,7 +29,8 @@ class ImpulseResponseAugmentor(AugmentorBase): ...@@ -28,7 +29,8 @@ class ImpulseResponseAugmentor(AugmentorBase):
def __init__(self, rng, impulse_manifest_path): def __init__(self, rng, impulse_manifest_path):
self._rng = rng self._rng = rng
self._impulse_manifest = read_manifest(impulse_manifest_path) with jsonlines.open(impulse_manifest_path, 'r') as reader:
self._impulse_manifest = list(reader)
def __call__(self, x, uttid=None, train=True): def __call__(self, x, uttid=None, train=True):
if not train: if not train:
......
...@@ -12,9 +12,10 @@ ...@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains the noise perturb augmentation model.""" """Contains the noise perturb augmentation model."""
import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase
from paddlespeech.s2t.frontend.utility import read_manifest
class NoisePerturbAugmentor(AugmentorBase): class NoisePerturbAugmentor(AugmentorBase):
...@@ -34,7 +35,8 @@ class NoisePerturbAugmentor(AugmentorBase): ...@@ -34,7 +35,8 @@ class NoisePerturbAugmentor(AugmentorBase):
self._min_snr_dB = min_snr_dB self._min_snr_dB = min_snr_dB
self._max_snr_dB = max_snr_dB self._max_snr_dB = max_snr_dB
self._rng = rng self._rng = rng
self._noise_manifest = read_manifest(manifest_path=noise_manifest_path) with jsonlines.open(noise_manifest_path, 'r') as reader:
self._noise_manifest = list(reader)
def __call__(self, x, uttid=None, train=True): def __call__(self, x, uttid=None, train=True):
if not train: if not train:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Contains feature normalizers.""" """Contains feature normalizers."""
import json import json
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle.io import DataLoader from paddle.io import DataLoader
...@@ -21,7 +22,6 @@ from paddle.io import Dataset ...@@ -21,7 +22,6 @@ from paddle.io import Dataset
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
__all__ = ["FeatureNormalizer"] __all__ = ["FeatureNormalizer"]
...@@ -61,7 +61,10 @@ class CollateFunc(object): ...@@ -61,7 +61,10 @@ class CollateFunc(object):
class AudioDataset(Dataset): class AudioDataset(Dataset):
def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0): def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):
self._rng = rng if rng else np.random.RandomState(random_seed) self._rng = rng if rng else np.random.RandomState(random_seed)
manifest = read_manifest(manifest_path)
with jsonlines.open(manifest_path, 'r') as reader:
manifest = list(reader)
if num_samples == -1: if num_samples == -1:
sampled_manifest = manifest sampled_manifest = manifest
else: else:
......
...@@ -98,7 +98,6 @@ def read_manifest( ...@@ -98,7 +98,6 @@ def read_manifest(
Returns: Returns:
List[dict]: Manifest parsing results. List[dict]: Manifest parsing results.
""" """
manifest = [] manifest = []
with jsonlines.open(manifest_path, 'r') as reader: with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader: for json_data in reader:
......
...@@ -16,10 +16,10 @@ from typing import Dict ...@@ -16,10 +16,10 @@ from typing import Dict
from typing import List from typing import List
from typing import Text from typing import Text
import jsonlines
import numpy as np import numpy as np
from paddle.io import DataLoader from paddle.io import DataLoader
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.io.batchfy import make_batchset from paddlespeech.s2t.io.batchfy import make_batchset
from paddlespeech.s2t.io.converter import CustomConverter from paddlespeech.s2t.io.converter import CustomConverter
from paddlespeech.s2t.io.dataset import TransformDataset from paddlespeech.s2t.io.dataset import TransformDataset
...@@ -91,7 +91,9 @@ class BatchDataLoader(): ...@@ -91,7 +91,9 @@ class BatchDataLoader():
self.n_iter_processes = n_iter_processes self.n_iter_processes = n_iter_processes
# read json data # read json data
self.data_json = read_manifest(json_file) with jsonlines.open(json_file, 'r') as reader:
self.data_json = list(reader)
self.feat_dim, self.vocab_size = feat_dim_and_vocab_size( self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr') self.data_json, mode='asr')
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import Optional from typing import Optional
import jsonlines
from paddle.io import Dataset from paddle.io import Dataset
from yacs.config import CfgNode from yacs.config import CfgNode
...@@ -184,7 +185,8 @@ class AudioDataset(Dataset): ...@@ -184,7 +185,8 @@ class AudioDataset(Dataset):
""" """
assert batch_type in ['static', 'dynamic'] assert batch_type in ['static', 'dynamic']
# read manifest # read manifest
data = read_manifest(data_file) with jsonlines.open(data_file, 'r') as reader:
data = list(reader)
if sort: if sort:
data = sorted(data, key=lambda x: x["feat_shape"][0]) data = sorted(data, key=lambda x: x["feat_shape"][0])
if raw_wav: if raw_wav:
......
...@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): ...@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False
......
...@@ -19,7 +19,7 @@ from pathlib import Path ...@@ -19,7 +19,7 @@ from pathlib import Path
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from tensorboardX import SummaryWriter from visualdl import LogWriter
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report from paddlespeech.s2t.training.reporter import report
...@@ -309,9 +309,10 @@ class Trainer(): ...@@ -309,9 +309,10 @@ class Trainer():
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(
'epoch', {'cv_loss': cv_loss, tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch) self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
# after epoch # after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
...@@ -427,7 +428,7 @@ class Trainer(): ...@@ -427,7 +428,7 @@ class Trainer():
unexpected behaviors. unexpected behaviors.
""" """
# visualizer # visualizer
visualizer = SummaryWriter(logdir=str(self.visual_dir)) visualizer = LogWriter(logdir=str(self.visual_dir))
self.visualizer = visualizer self.visualizer = visualizer
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
......
...@@ -21,7 +21,7 @@ import wave ...@@ -21,7 +21,7 @@ import wave
from time import gmtime from time import gmtime
from time import strftime from time import strftime
from paddlespeech.s2t.frontend.utility import read_manifest import jsonlines
__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"] __all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]
...@@ -44,7 +44,8 @@ def warm_up_test(audio_process_handler, ...@@ -44,7 +44,8 @@ def warm_up_test(audio_process_handler,
num_test_cases, num_test_cases,
random_seed=0): random_seed=0):
"""Warming-up test.""" """Warming-up test."""
manifest = read_manifest(manifest_path) with jsonlines.open(manifest_path) as reader:
manifest = list(reader)
rng = random.Random(random_seed) rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases) samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples): for idx, sample in enumerate(samples):
......
...@@ -34,7 +34,7 @@ from speechtask.punctuation_restoration.model.lstm import RnnLm ...@@ -34,7 +34,7 @@ from speechtask.punctuation_restoration.model.lstm import RnnLm
from speechtask.punctuation_restoration.utils import layer_tools from speechtask.punctuation_restoration.utils import layer_tools
from speechtask.punctuation_restoration.utils import mp_tools from speechtask.punctuation_restoration.utils import mp_tools
from speechtask.punctuation_restoration.utils.checkpoint import Checkpoint from speechtask.punctuation_restoration.utils.checkpoint import Checkpoint
from tensorboardX import SummaryWriter from visualdl import LogWriter
__all__ = ["Trainer", "Tester"] __all__ = ["Trainer", "Tester"]
...@@ -252,10 +252,10 @@ class Trainer(): ...@@ -252,10 +252,10 @@ class Trainer():
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}". self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
format(self.epoch, total_loss, F1_score)) format(self.epoch, total_loss, F1_score))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars("epoch", { self.visualizer.add_scalar(
"total_loss": total_loss, tag='eval/cv_loss', value=cv_loss, step=self.epoch)
"lr": self.lr_scheduler() self.visualizer.add_scalar(
}, self.epoch) tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save( self.save(
tag=self.epoch, infos={"val_loss": total_loss, tag=self.epoch, infos={"val_loss": total_loss,
...@@ -341,7 +341,7 @@ class Trainer(): ...@@ -341,7 +341,7 @@ class Trainer():
unexpected behaviors. unexpected behaviors.
""" """
# visualizer # visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir)) visualizer = LogWriter(logdir=str(self.output_dir))
self.visualizer = visualizer self.visualizer = visualizer
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
......
...@@ -40,7 +40,6 @@ snakeviz ...@@ -40,7 +40,6 @@ snakeviz
soundfile~=0.10 soundfile~=0.10
sox sox
soxbindings soxbindings
tensorboardX
textgrid textgrid
timer timer
tqdm tqdm
......
...@@ -21,9 +21,10 @@ import os ...@@ -21,9 +21,10 @@ import os
import tempfile import tempfile
from collections import Counter from collections import Counter
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.frontend.utility import SOS from paddlespeech.s2t.frontend.utility import SOS
from paddlespeech.s2t.frontend.utility import SPACE from paddlespeech.s2t.frontend.utility import SPACE
from paddlespeech.s2t.frontend.utility import UNK from paddlespeech.s2t.frontend.utility import UNK
...@@ -59,13 +60,21 @@ args = parser.parse_args() ...@@ -59,13 +60,21 @@ args = parser.parse_args()
def count_manifest(counter, text_feature, manifest_path): def count_manifest(counter, text_feature, manifest_path):
manifest_jsons = read_manifest(manifest_path) manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'], replace_space=False) line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line) counter.update(line)
def dump_text_manifest(fileobj, manifest_path, key='text'): def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = read_manifest(manifest_path) manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
fileobj.write(line_json[key] + "\n") fileobj.write(line_json[key] + "\n")
......
...@@ -17,7 +17,7 @@ import argparse ...@@ -17,7 +17,7 @@ import argparse
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from paddlespeech.s2t.frontend.utility import read_manifest import jsonlines
key_whitelist = set(['feat', 'text', 'syllable', 'phone']) key_whitelist = set(['feat', 'text', 'syllable', 'phone'])
filename = { filename = {
...@@ -32,7 +32,10 @@ def dump_manifest(manifest_path, output_dir: Union[str, Path]): ...@@ -32,7 +32,10 @@ def dump_manifest(manifest_path, output_dir: Union[str, Path]):
output_dir = Path(output_dir).expanduser() output_dir = Path(output_dir).expanduser()
manifest_path = Path(manifest_path).expanduser() manifest_path = Path(manifest_path).expanduser()
manifest_jsons = read_manifest(manifest_path)
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
first_line = manifest_jsons[0] first_line = manifest_jsons[0]
file_map = {} file_map = {}
......
...@@ -17,9 +17,10 @@ import argparse ...@@ -17,9 +17,10 @@ import argparse
import functools import functools
import json import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.io.utility import feat_type from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments
...@@ -71,7 +72,9 @@ def main(): ...@@ -71,7 +72,9 @@ def main():
# } # }
count = 0 count = 0
for manifest_path in args.manifest_paths: for manifest_path in args.manifest_paths:
manifest_jsons = read_manifest(manifest_path) with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
for line_json in manifest_jsons: for line_json in manifest_jsons:
output_json = { output_json = {
"input": [], "input": [],
......
...@@ -17,9 +17,10 @@ import argparse ...@@ -17,9 +17,10 @@ import argparse
import functools import functools
import json import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.io.utility import feat_type from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments
...@@ -63,7 +64,8 @@ def main(): ...@@ -63,7 +64,8 @@ def main():
count = 0 count = 0
for manifest_path in args.manifest_paths: for manifest_path in args.manifest_paths:
manifest_jsons = read_manifest(manifest_path) with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
for line_json in manifest_jsons: for line_json in manifest_jsons:
# text: translation text, text1: transcript text. # text: translation text, text1: transcript text.
# Currently only support joint-vocab, will add separate vocabs setting. # Currently only support joint-vocab, will add separate vocabs setting.
......
...@@ -4,9 +4,10 @@ import argparse ...@@ -4,9 +4,10 @@ import argparse
import functools import functools
from pathlib import Path from pathlib import Path
import jsonlines
from utils.utility import add_arguments from utils.utility import add_arguments
from utils.utility import print_arguments from utils.utility import print_arguments
from utils.utility import read_manifest
def main(args): def main(args):
...@@ -19,7 +20,8 @@ def main(args): ...@@ -19,7 +20,8 @@ def main(args):
dur_scp = outdir / 'duration' dur_scp = outdir / 'duration'
text_scp = outdir / 'text' text_scp = outdir / 'text'
manifest_jsons = read_manifest(args.manifest_path) with jsonlines.open(args.manifest_path, 'r') as reader:
manifest_jsons = list(reader)
with wav_scp.open('w') as fwav, dur_scp.open('w') as fdur, text_scp.open( with wav_scp.open('w') as fwav, dur_scp.open('w') as fdur, text_scp.open(
'w') as ftxt: 'w') as ftxt:
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib import hashlib
import json
import os import os
import sys import sys
import tarfile import tarfile
...@@ -22,31 +21,10 @@ from typing import Text ...@@ -22,31 +21,10 @@ from typing import Text
__all__ = [ __all__ = [
"check_md5sum", "getfile_insensitive", "download_multi", "download", "check_md5sum", "getfile_insensitive", "download_multi", "download",
"unpack", "unzip", "md5file", "print_arguments", "add_arguments", "unpack", "unzip", "md5file", "print_arguments", "add_arguments",
"read_manifest", "get_commandline_args" "get_commandline_args"
] ]
def read_manifest(manifest_path):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
for json_line in open(manifest_path, 'r'):
try:
json_data = json.loads(json_line)
except Exception as e:
raise IOError("Error reading manifest: %s" % str(e))
return manifest
def get_commandline_args(): def get_commandline_args():
extra_chars = [ extra_chars = [
" ", " ",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册