提交 39228864 编写于 作者: H Hui Zhang

format code

上级 d395c2b8
......@@ -339,6 +339,3 @@ You need to prepare an audio file, please confirm the sample rate of the audio i
```bash
CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav
```
......@@ -129,8 +129,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag='train/'+key, value=val, step=self.iteration-1)
self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
......@@ -238,8 +238,10 @@ class U2Trainer(Trainer):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
......
......@@ -132,7 +132,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
......@@ -222,9 +223,11 @@ class U2Trainer(Trainer):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
......
......@@ -139,7 +139,8 @@ class U2STTrainer(Trainer):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
......@@ -235,9 +236,11 @@ class U2STTrainer(Trainer):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
......
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the impulse response augmentation model."""
import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase
......
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the noise perturb augmentation model."""
import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase
......
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains feature normalizers."""
import json
import jsonlines
import numpy as np
import paddle
......@@ -26,7 +27,8 @@ from paddlespeech.s2t.utils.log import Log
__all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog()
# https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object):
def __init__(self, feature_func):
......@@ -62,7 +64,7 @@ class AudioDataset(Dataset):
with jsonlines.open(manifest_path, 'r') as reader:
manifest = list(reader)
if num_samples == -1:
sampled_manifest = manifest
else:
......
......@@ -64,7 +64,7 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
char_list.append(MASKCTC)
return char_list
def read_manifest(
manifest_path,
max_input_len=float('inf'),
......
......@@ -15,8 +15,8 @@ from typing import Any
from typing import Dict
from typing import List
from typing import Text
import jsonlines
import jsonlines
import numpy as np
from paddle.io import DataLoader
......@@ -93,7 +93,7 @@ class BatchDataLoader():
# read json data
with jsonlines.open(json_file, 'r') as reader:
self.data_json = list(reader)
self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr')
......
......@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import Optional
import jsonlines
from paddle.io import Dataset
from yacs.config import CfgNode
......
......@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
......
......@@ -309,8 +309,10 @@ class Trainer():
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
# after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
......
......@@ -20,6 +20,7 @@ import time
import wave
from time import gmtime
from time import strftime
import jsonlines
__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]
......
......@@ -252,8 +252,10 @@ class Trainer():
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
format(self.epoch, total_loss, F1_score))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(
tag=self.epoch, infos={"val_loss": total_loss,
......
......@@ -19,9 +19,10 @@ import argparse
import functools
import os
import tempfile
import jsonlines
from collections import Counter
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
......@@ -63,7 +64,7 @@ def count_manifest(counter, text_feature, manifest_path):
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line)
......@@ -73,7 +74,7 @@ def dump_text_manifest(fileobj, manifest_path, key='text'):
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
fileobj.write(line_json[key] + "\n")
......
......@@ -16,6 +16,7 @@
import argparse
from pathlib import Path
from typing import Union
import jsonlines
key_whitelist = set(['feat', 'text', 'syllable', 'phone'])
......@@ -34,7 +35,7 @@ def dump_manifest(manifest_path, output_dir: Union[str, Path]):
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
first_line = manifest_jsons[0]
file_map = {}
......
......@@ -15,9 +15,10 @@
"""format manifest with more metadata."""
import argparse
import functools
import jsonlines
import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type
......@@ -73,7 +74,7 @@ def main():
for manifest_path in args.manifest_paths:
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
for line_json in manifest_jsons:
output_json = {
"input": [],
......
......@@ -16,6 +16,7 @@
import argparse
import functools
import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
......
......@@ -3,6 +3,7 @@
import argparse
import functools
from pathlib import Path
import jsonlines
from utils.utility import add_arguments
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import os
import sys
import tarfile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册